I'm learning data science, in particular how to train a U-Net for semantic segmentation.

I have a U-Net with this loss function:

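from tensorflow.keras import backend as K

def dice_coef(y_true, y_pred):
    # Cast to float32 (Python's float() cannot convert a multi-element tensor) and flatten to 1-D
    y_true_f = K.flatten(K.cast(y_true, 'float32'))
    y_pred_f = K.flatten(K.cast(y_pred, 'float32'))
    intersection = K.sum(y_true_f * y_pred_f)
    # Smoothed Dice score; true division (/) keeps it continuous and differentiable,
    # whereas floor division (//) would round it down to 0 or 1
    return (2. * intersection + 1.) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.)

def dice_coef_loss(y_true, y_pred):
    # Perfect overlap (Dice = 1) gives a loss of 0
    return 1. - dice_coef(y_true, y_pred)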
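The division operator matters a lot for a Dice loss. The NumPy sketch below uses a hypothetical helper, `dice_coef_np`, that mirrors the Keras function on a made-up 2×3 mask (not the asker's data) and computes the same smoothed score with `/` and with `//`. For predictions in [0, 1] the ratio lies in (0, 1], so floor division returns 0 for anything short of an exact match. Because floor is piecewise constant, its gradient is zero almost everywhere, and the optimizer gets no signal from such a loss.

```python
import numpy as np

def dice_coef_np(y_true, y_pred, smooth=1.0, floor_div=False):
    """Smoothed Dice coefficient on NumPy arrays (hypothetical helper for illustration)."""
    y_true_f = np.asarray(y_true, dtype=np.float32).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float32).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    numerator = 2.0 * intersection + smooth
    denominator = np.sum(y_true_f) + np.sum(y_pred_f) + smooth
    # Floor division rounds the ratio down to 0 unless the prediction
    # matches the mask exactly, so it no longer tracks partial overlap.
    return numerator // denominator if floor_div else numerator / denominator

y_true = np.array([[0, 1, 1], [0, 1, 0]])
y_pred = np.array([[0.1, 0.8, 0.7], [0.2, 0.9, 0.1]])

print(dice_coef_np(y_true, y_pred))                  # true division: partial overlap
print(dice_coef_np(y_true, y_pred, floor_div=True))  # floor division: collapses to 0
```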

With the same training and validation data, the model performs better with binary_crossentropy than with dice_coef_loss.

With binary_crossentropy I get this output:

Epoch 1/50
  2/698 [..............................] - ETA: 53s - loss: 0.9674 - accuracy: 0.6257
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0432s vs `on_train_batch_end` time: 0.1084s). Check your callbacks.
698/698 [==============================] - 117s 168ms/step - loss: 0.0661 - accuracy: 0.9848 - val_loss: 0.0379 - val_accuracy: 0.9902
Epoch 2/50
698/698 [==============================] - 115s 165ms/step - loss: 0.0329 - accuracy: 0.9902 - val_loss: 0.0313 - val_accuracy: 0.9901
Epoch 3/50
698/698 [==============================] - 115s 165ms/step - loss: 0.0190 - accuracy: 0.9938 - val_loss: 0.0243 - val_accuracy: 0.9920
Epoch 4/50
698/698 [==============================] - 116s 166ms/step - loss: 0.0154 - accuracy: 0.9948 - val_loss: 0.0105 - val_accuracy: 0.9963
Epoch 5/50
698/698 [==============================] - 116s 166ms/step - loss: 0.0090 - accuracy: 0.9967 - val_loss: 0.0094 - val_accuracy: 0.9966
Epoch 6/50
698/698 [==============================] - 116s 166ms/step - loss: 0.0083 - accuracy: 0.9970 - val_loss: 0.0143 - val_accuracy: 0.9948
Epoch 7/50
698/698 [==============================] - 116s 166ms/step - loss: 0.0122 - accuracy: 0.9958 - val_loss: 0.0073 - val_accuracy: 0.9972
Epoch 8/50
698/698 [==============================] - 115s 165ms/step - loss: 0.0055 - accuracy: 0.9979 - val_loss: 0.0053 - val_accuracy: 0.9979
Epoch 9/50
698/698 [==============================] - 116s 166ms/step - loss: 0.0045 - accuracy: 0.9982 - val_loss: 0.0047 - val_accuracy: 0.9982
Epoch 10/50
698/698 [==============================] - 115s 165ms/step - loss: 0.0047 - accuracy: 0.9981 - val_loss: 0.0044 - val_accuracy: 0.9982
Epoch 11/50
698/698 [==============================] - 116s 166ms/step - loss: 0.0041 - accuracy: 0.9983 - val_loss: 0.0050 - val_accuracy: 0.9980
Epoch 12/50
698/698 [==============================] - 115s 165ms/step - loss: 0.1478 - accuracy: 0.9962 - val_loss: 0.0844 - val_accuracy: 0.9849
Epoch 13/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0478 - accuracy: 0.9872 - val_loss: 0.0290 - val_accuracy: 0.9902
Epoch 14/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0218 - accuracy: 0.9924 - val_loss: 0.0167 - val_accuracy: 0.9941
Epoch 15/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0140 - accuracy: 0.9950 - val_loss: 0.0127 - val_accuracy: 0.9956
Epoch 16/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0103 - accuracy: 0.9961 - val_loss: 0.0122 - val_accuracy: 0.9956
Epoch 17/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0096 - accuracy: 0.9964 - val_loss: 0.0084 - val_accuracy: 0.9970
Epoch 18/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0086 - accuracy: 0.9967 - val_loss: 0.0074 - val_accuracy: 0.9972
Epoch 19/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0066 - accuracy: 0.9975 - val_loss: 0.0080 - val_accuracy: 0.9970
Epoch 20/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0103 - accuracy: 0.9965 - val_loss: 0.0145 - val_accuracy: 0.9951
Epoch 21/50
698/698 [==============================] - 113s 163ms/step - loss: 0.0065 - accuracy: 0.9976 - val_loss: 0.0055 - val_accuracy: 0.9979
Epoch 22/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0051 - accuracy: 0.9981 - val_loss: 0.0057 - val_accuracy: 0.9978
Epoch 23/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0058 - accuracy: 0.9977 - val_loss: 0.0051 - val_accuracy: 0.9981
Epoch 24/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0046 - accuracy: 0.9982 - val_loss: 0.0055 - val_accuracy: 0.9980
Epoch 25/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0044 - accuracy: 0.9983 - val_loss: 0.0051 - val_accuracy: 0.9981
Epoch 26/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0049 - accuracy: 0.9981 - val_loss: 0.0089 - val_accuracy: 0.9968
Epoch 27/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0045 - accuracy: 0.9982 - val_loss: 0.0043 - val_accuracy: 0.9983
Epoch 28/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0038 - accuracy: 0.9985 - val_loss: 0.0044 - val_accuracy: 0.9984
Epoch 29/50
698/698 [==============================] - 113s 161ms/step - loss: 0.0069 - accuracy: 0.9975 - val_loss: 0.0061 - val_accuracy: 0.9978
Epoch 30/50
698/698 [==============================] - 113s 161ms/step - loss: 0.0039 - accuracy: 0.9984 - val_loss: 0.0045 - val_accuracy: 0.9982
Epoch 31/50
698/698 [==============================] - 113s 161ms/step - loss: 0.0033 - accuracy: 0.9986 - val_loss: 0.0038 - val_accuracy: 0.9985
Epoch 32/50
698/698 [==============================] - 112s 161ms/step - loss: 0.0032 - accuracy: 0.9987 - val_loss: 0.0041 - val_accuracy: 0.9984
Epoch 33/50
698/698 [==============================] - 113s 161ms/step - loss: 0.0033 - accuracy: 0.9986 - val_loss: 0.0037 - val_accuracy: 0.9985
Epoch 34/50
698/698 [==============================] - 113s 161ms/step - loss: 0.0032 - accuracy: 0.9987 - val_loss: 0.0038 - val_accuracy: 0.9985
Epoch 35/50
698/698 [==============================] - 112s 161ms/step - loss: 0.0030 - accuracy: 0.9987 - val_loss: 0.0039 - val_accuracy: 0.9985
Epoch 36/50
698/698 [==============================] - 112s 161ms/step - loss: 0.0074 - accuracy: 0.9971 - val_loss: 0.0046 - val_accuracy: 0.9982
Epoch 37/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0031 - accuracy: 0.9987 - val_loss: 0.0033 - val_accuracy: 0.9987
Epoch 38/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0027 - accuracy: 0.9989 - val_loss: 0.0032 - val_accuracy: 0.9987
Epoch 39/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0026 - accuracy: 0.9989 - val_loss: 0.0032 - val_accuracy: 0.9987
Epoch 40/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0131 - accuracy: 0.9960 - val_loss: 0.0041 - val_accuracy: 0.9984
Epoch 41/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0031 - accuracy: 0.9987 - val_loss: 0.0033 - val_accuracy: 0.9987
Epoch 42/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0025 - accuracy: 0.9989 - val_loss: 0.0032 - val_accuracy: 0.9987
Epoch 43/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0025 - accuracy: 0.9990 - val_loss: 0.0032 - val_accuracy: 0.9987
Epoch 44/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0024 - accuracy: 0.9990 - val_loss: 0.0034 - val_accuracy: 0.9986
Epoch 45/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0026 - accuracy: 0.9989 - val_loss: 0.0036 - val_accuracy: 0.9986
Epoch 46/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0025 - accuracy: 0.9989 - val_loss: 0.0031 - val_accuracy: 0.9988
Epoch 47/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0024 - accuracy: 0.9990 - val_loss: 0.0036 - val_accuracy: 0.9987
Epoch 48/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0025 - accuracy: 0.9990 - val_loss: 0.0032 - val_accuracy: 0.9987
Epoch 49/50
698/698 [==============================] - 113s 162ms/step - loss: 0.0024 - accuracy: 0.9990 - val_loss: 0.0030 - val_accuracy: 0.9988
Epoch 50/50
698/698 [==============================] - 113s 161ms/step - loss: 0.0049 - accuracy: 0.9981 - val_loss: 0.0034 - val_accuracy: 0.9987

With dice_coef_loss I get this output:

Epoch 1/50
  2/582 [..............................] - ETA: 1:36 - loss: 0.9994 - accuracy: 0.9923
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0626s vs `on_train_batch_end` time: 0.1113s). Check your callbacks.
582/582 [==============================] - 95s 163ms/step - loss: 0.9160 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 2/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8988 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 3/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9240 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 4/50
582/582 [==============================] - 94s 161ms/step - loss: 0.9027 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 5/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8840 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 6/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8894 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 7/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9052 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 8/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8961 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 9/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9190 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 10/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9085 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 11/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9150 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 12/50
582/582 [==============================] - 94s 161ms/step - loss: 0.9162 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 13/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9103 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 14/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9028 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 15/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8866 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 16/50
582/582 [==============================] - 94s 161ms/step - loss: 0.9127 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 17/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9006 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 18/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8809 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 19/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9080 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 20/50
582/582 [==============================] - 93s 160ms/step - loss: 0.8952 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 21/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8952 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 22/50
582/582 [==============================] - 93s 160ms/step - loss: 0.8969 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 23/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8919 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 24/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8935 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 25/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9035 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 26/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9073 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 27/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9005 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 28/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9041 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 29/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8902 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 30/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8909 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 31/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9097 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 32/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9130 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 33/50
582/582 [==============================] - 94s 161ms/step - loss: 0.9026 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 34/50
582/582 [==============================] - 94s 161ms/step - loss: 0.9002 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 35/50
582/582 [==============================] - 93s 161ms/step - loss: 0.9153 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 36/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8931 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 37/50
582/582 [==============================] - 94s 161ms/step - loss: 0.9148 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 38/50
582/582 [==============================] - 94s 161ms/step - loss: 0.9007 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 39/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8901 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 40/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8930 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 41/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8991 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 42/50
582/582 [==============================] - 94s 161ms/step - loss: 0.8946 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 43/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8978 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 44/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9179 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 45/50
582/582 [==============================] - 93s 160ms/step - loss: 0.8976 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 46/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9051 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 47/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9082 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 48/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9040 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 49/50
582/582 [==============================] - 93s 161ms/step - loss: 0.8989 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853
Epoch 50/50
582/582 [==============================] - 93s 160ms/step - loss: 0.9231 - accuracy: 0.9862 - val_loss: 0.9218 - val_accuracy: 0.9853

Any advice on why I'm getting better loss values with binary cross-entropy than with the Dice coefficient loss?

These results make me doubt whether binary cross-entropy is really the best loss function for my task.


Solution

The question asks why different loss functions lead to different error scores.

Broadly speaking, a loss (error) function measures the discrepancy between the model's output and the output we actually want. Different loss functions formulate this discrepancy differently, so their values live on different scales and, depending on the task, some are more appropriate than others. In particular, a lower value of one loss does not by itself mean a better model than a higher value of another loss.
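To make the scale difference concrete, here is a minimal NumPy sketch that scores one prediction with both losses. The functions `bce_loss` and `dice_loss` and the toy 100-pixel mask with 2 foreground pixels are hypothetical illustrations, not taken from the question.

```python
import numpy as np

def bce_loss(y_true, y_pred, eps=1e-7):
    """Mean per-pixel binary cross-entropy."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip predictions so log() never sees exactly 0 or 1
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))

def dice_loss(y_true, y_pred, smooth=1.0):
    """1 - smoothed Dice coefficient, with the same +1 smoothing as the question's code."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return float(1 - (2 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth))

# Toy, heavily imbalanced mask: 2 foreground pixels out of 100
y_true = np.zeros(100)
y_true[:2] = 1
# A mediocre prediction: low probability on background, 0.6 on the object
y_pred = np.full(100, 0.05)
y_pred[:2] = 0.6

print(bce_loss(y_true, y_pred), dice_loss(y_true, y_pred))
```

For this same prediction, BCE comes out at about 0.06 while the Dice loss is about 0.63, because BCE averages over all 100 pixels (mostly easy background) and Dice only looks at the foreground overlap.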

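Binary cross-entropy is particularly well suited to binary classification, where we discriminate between two classes, because it scores the predicted probability of the positive class against the true label. In segmentation it is applied to each pixel independently and then averaged.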
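For reference, with ground-truth label $y_i \in \{0, 1\}$ and predicted probability $p_i$ for each of the $N$ pixels, the averaged binary cross-entropy is:

$$
\mathcal{L}_{\mathrm{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\left[\, y_i \log p_i + (1 - y_i)\log(1 - p_i) \,\right]
$$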

The Dice coefficient measures the overlap between the model's output and the desired output. This is particularly useful for semantic segmentation, where we want to evaluate whether the predicted mask for a given object matches the ground-truth mask.
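A short NumPy sketch of why overlap matters: with the smoothed Dice score, 2·|overlap| + 1 divided by |mask| + |prediction| + 1, a prediction that misses a small object entirely still reaches about 98% pixel accuracy, yet its Dice score is close to 0. The helpers `dice_score` and `pixel_accuracy` and the 64×64 mask with an 8×8 object are hypothetical, chosen only for illustration.

```python
import numpy as np

def dice_score(y_true, y_pred, smooth=1.0):
    """Smoothed Dice coefficient between a ground-truth mask and a prediction."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return float((2 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth))

def pixel_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of pixels whose thresholded prediction matches the mask."""
    return float(np.mean((np.asarray(y_pred) >= threshold) == (np.asarray(y_true) >= 0.5)))

# Hypothetical 64x64 mask with a small 8x8 object (about 1.6% foreground)
mask = np.zeros((64, 64))
mask[10:18, 10:18] = 1

all_background = np.zeros_like(mask)  # the object is missed entirely
perfect = mask.copy()                 # exact match

print(pixel_accuracy(mask, all_background), dice_score(mask, all_background))
print(pixel_accuracy(mask, perfect), dice_score(mask, perfect))
```

This is one way a model can report high, unchanging accuracy on an imbalanced dataset while its overlap with the ground truth stays poor.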

Licensed under: CC-BY-SA with attribution