Question

I'm training a binary classifier and I'd like to see Precision/Recall metrics at different thresholds.

In TensorFlow 2.3, tf.keras.metrics.Precision and tf.keras.metrics.Recall take a thresholds parameter, in which you can specify one or more thresholds at which the metric is computed. This works as advertised, e.g.:

import tensorflow as tf

m = tf.keras.metrics.Precision(thresholds=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
m.update_state([0, 1, 0, 1], [0.4, 0.5, 0.3, 0.8])  # y_true, y_pred
m.result().numpy()

This returns the precision at each threshold, [0.5, 0.5, 0.6666667, 1., 1., 1.], as per the documentation.
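To see where these numbers come from, here is a NumPy-only sketch that recomputes precision at each threshold for the same four samples. It assumes, as Keras does internally, that a prediction counts as positive only when it is strictly greater than the threshold (with `>=`, the 0.3 threshold would give 0.5 instead of 0.667). The helper name `precision_at_thresholds` is hypothetical, not a Keras API.

```python
import numpy as np


def precision_at_thresholds(y_true, y_pred, thresholds):
    """Precision for each threshold; a prediction is positive if it is > threshold."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=np.float32)
    # Compare in float32 on both sides so that 0.3 > 0.3 is False.
    thresholds = np.asarray(thresholds, dtype=np.float32)
    # Shape (n_thresholds, n_samples): True where the sample is predicted positive.
    pred_pos = y_pred[None, :] > thresholds[:, None]
    tp = np.sum(pred_pos & y_true[None, :], axis=1)
    fp = np.sum(pred_pos & ~y_true[None, :], axis=1)
    # Return 0.0 instead of 0/0 when nothing is predicted positive.
    denom = tp + fp
    return np.where(denom > 0, tp / np.maximum(denom, 1), 0.0)


print(precision_at_thresholds([0, 1, 0, 1], [0.4, 0.5, 0.3, 0.8],
                              [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]))
# ≈ [0.5, 0.5, 0.667, 1.0, 1.0, 1.0]
```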

However, when I pass them as metrics to Model.compile, I get a single value per metric, regardless of how many thresholds I specify.


import numpy as np
from tensorflow import keras

pr_thresholds = list(np.arange(0.05, 0.95, 0.05))  # 0.05, 0.10, ..., 0.90
model.compile(
    'adam',
    'binary_crossentropy',
    metrics=[
        keras.metrics.Precision(thresholds=pr_thresholds),
        keras.metrics.Recall(thresholds=pr_thresholds),
    ]
)

The training log then shows a single precision value and a single recall value:

Epoch 34/50
395/395 [==============================] - 22s 54ms/step - loss: 0.4314 - precision: 0.7886 - recall: 0.9008 - val_loss: 0.5113 - val_precision: 0.7434 - val_recall: 0.8769

What's happening here? Does it always use the default threshold value of 0.5 in this case?

Is there a way I can get it to display the values for multiple thresholds during training?


Solution

You can see the metric value for each threshold during fitting if you explicitly instantiate a separate metric object for each threshold, as follows:

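from tensorflow import keras
from tensorflow.keras import metrics

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-2),
    # For a single sigmoid output, as in the question, use 'binary_crossentropy'.
    loss='categorical_crossentropy',
    # One Recall object per threshold, so each gets its own entry in the log.
    metrics=[metrics.Recall(thresholds=0.6),
             metrics.Recall(thresholds=0.9)])

model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))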

As the screenshot of the training log below shows, at every epoch the first recall value (threshold 0.6) is higher than the second one (threshold 0.9), as expected.

[Screenshot: training log with two recall values per epoch, for thresholds 0.6 and 0.9]
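The "as expected" follows from how recall is defined: raising the threshold can only turn predicted positives into predicted negatives, so true positives never increase, while the denominator (the number of actual positives) stays fixed. A minimal NumPy sketch on the four toy samples from the question, assuming Keras's strict `y_pred > threshold` rule; `recall_at` is a hypothetical helper, not a Keras API:

```python
import numpy as np

# Toy data from the question.
y_true = np.array([0, 1, 0, 1], dtype=bool)
y_pred = np.array([0.4, 0.5, 0.3, 0.8], dtype=np.float32)


def recall_at(threshold):
    """Recall = TP / actual positives, where 'predicted positive' means y_pred > threshold."""
    pred_pos = y_pred > np.float32(threshold)
    return float(np.sum(pred_pos & y_true) / np.sum(y_true))


# Sweep thresholds from 0.0 to 1.0 in steps of 0.05.
grid = np.linspace(0.0, 1.0, 21)
recalls = [recall_at(t) for t in grid]
# Recall can only stay the same or drop as the threshold rises.
assert all(a >= b for a, b in zip(recalls, recalls[1:]))

print(recall_at(0.6), recall_at(0.9))  # 0.5 0.0
```

Precision has no such guarantee: it can go up or down as the threshold changes, because its denominator (predicted positives) shrinks along with the true positives.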

For your case, you can build the list of metric objects programmatically and pass it to model.compile as metrics=metrics_objs_list; the log then shows three recall values per epoch:

from tensorflow.keras import metrics

thresholds = [0.6, 0.7, 0.9]
metrics_objs_list = [metrics.Recall(thresholds=thr) for thr in thresholds]

[Screenshot: training log with three recall values per epoch, for thresholds 0.6, 0.7 and 0.9]
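To apply the same idea to the precision/recall grid from the question, a small helper can create one metric object per threshold and give each a readable, distinct name (e.g. `precision_30`), so the log columns identify their thresholds. The helper `per_threshold_metrics` and the naming scheme are hypothetical; the sketch only assumes that the metric class accepts `thresholds=` and `name=` keyword arguments, which `tf.keras.metrics.Precision` and `tf.keras.metrics.Recall` do. Rounding also removes float noise from `np.arange`, such as `0.15000000000000002`.

```python
import numpy as np


def per_threshold_metrics(metric_cls, prefix, thresholds):
    """Build one metric object per threshold, each with a readable, unique name.

    metric_cls is any class accepting `thresholds=` and `name=` keyword
    arguments, e.g. tf.keras.metrics.Precision or tf.keras.metrics.Recall.
    """
    objs = []
    for t in thresholds:
        t = float(round(t, 2))  # strip float noise such as 0.15000000000000002
        # Encode the threshold as a percentage so the name has no dot: 0.3 -> "precision_30".
        objs.append(metric_cls(thresholds=t, name=f"{prefix}_{round(t * 100):02d}"))
    return objs


# Same grid as in the question: 0.05, 0.10, ..., 0.90.
pr_thresholds = np.arange(0.05, 0.95, 0.05)

# Usage with Keras (requires TensorFlow and a built `model`):
# from tensorflow import keras
# model.compile(
#     "adam", "binary_crossentropy",
#     metrics=per_threshold_metrics(keras.metrics.Precision, "precision", pr_thresholds)
#     + per_threshold_metrics(keras.metrics.Recall, "recall", pr_thresholds),
# )
```

With 18 thresholds this produces 36 metric columns (plus their `val_` counterparts), so a coarser grid may keep the progress bar readable.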

Licensed under: CC-BY-SA with attribution