Question

I want to use class weights to train a CNN on an imbalanced data set. Does the sum of the weights over all examples have to stay the same?

My plan was to use the scikit-learn function compute_class_weight('balanced', classes=np.unique(y_train), y=y_train).
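To make the question about the total weight concrete: scikit-learn documents the 'balanced' mode as n_samples / (n_classes * np.bincount(y)). The sketch below recomputes that formula in plain Python (the helpers balanced_class_weights and total_sample_weight are hypothetical names, not scikit-learn API) and sums the weight over every example. Each class c contributes n_c * n / (k * n_c) = n / k, and there are k classes, so the 'balanced' heuristic happens to keep the summed weight equal to the number of examples n.

```python
from collections import Counter


def balanced_class_weights(y):
    """Hypothetical helper mirroring the documented 'balanced' formula:
    weight_c = n_samples / (n_classes * count_c)."""
    counts = Counter(y)
    n_samples = len(y)
    n_classes = len(counts)
    # Rarer classes get proportionally larger weights.
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


def total_sample_weight(y, weights):
    """Sum of per-example weights when each example receives its class weight."""
    return sum(weights[label] for label in y)


# Toy label vector (made-up data): 90 examples of class 0, 10 of class 1.
y_train = [0] * 90 + [1] * 10
weights = balanced_class_weights(y_train)
print(weights)
# The total equals len(y_train) up to floating-point rounding.
print(total_sample_weight(y_train, weights), len(y_train))
```

If you train with Keras, Model.fit accepts class_weight as a dict mapping class index to weight, which is the shape returned here; compute_class_weight itself returns an array ordered like classes, so you would build the dict with dict(zip(classes, weights)).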

But I'm not at all sure whether this is even suitable for weighting classes in a CNN.

Thanks in advance for any tips.


Solution

If experimenting is not too costly for you, I suggest treating this as a learning opportunity: just try it and see whether it actually works.

There are many approaches to class imbalance; setting class weights is one of them, and the easiest to implement:

  • Changing the loss function (for example, to focal loss for binary classification with extreme imbalance)
  • Oversampling and undersampling
  • Setting class weights
  • Using an algorithm built specifically for this problem, e.g. a siamese network, which is very useful when you have only a few training samples of the object of interest
  • etc.
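As an illustration of the first option, here is a minimal sketch of binary focal loss in its commonly used form, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), written in plain Python so the arithmetic is visible. The function names and the defaults gamma = 2.0 and alpha = 0.25 are assumptions for illustration; in practice you would use your framework's own implementation and tune these values. The (1 - p_t)^gamma factor shrinks the loss of examples the model already classifies confidently, which is how focal loss shifts attention to the rare, hard ones without a fixed per-class multiplier.

```python
import math


def binary_focal_loss(y_true, p_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Mean binary focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t).
    gamma and alpha defaults are assumed, not prescribed."""
    total = 0.0
    for y, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1.0 - eps)          # clip to avoid log(0)
        p_t = p if y == 1 else 1.0 - p           # probability given to the true class
        alpha_t = alpha if y == 1 else 1.0 - alpha
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(y_true)


def binary_cross_entropy(y_true, p_pred, eps=1e-7):
    """Plain mean binary cross-entropy, for comparison."""
    total = 0.0
    for y, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -math.log(p) if y == 1 else -math.log(1.0 - p)
    return total / len(y_true)


# An easy negative (predicted positive probability 0.05) vs. a hard positive (0.1).
for label, prob in [(0, 0.05), (1, 0.1)]:
    ce = binary_cross_entropy([label], [prob])
    fl = binary_focal_loss([label], [prob])
    print(f"y={label} p={prob}: cross-entropy={ce:.4f} focal={fl:.4f}")
```

With gamma = 0 and alpha = 0.5, the focal loss reduces to half the ordinary cross-entropy, which is a handy sanity check when wiring it into a model.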

For your case specifically, I can describe one situation where, in my experience, it is likely to fail: extreme class imbalance, say 1% positive and 99% negative. Class weighting then puts a very high weight on the positive samples, so if your model misses one of them the penalty is very large, which can make training unstable.

To make matters worse, consider a hypothetical situation where your model predicts the positive class correctly in epoch 10 and then misses it in epoch 11. The loss might be, for example, 1.3 in epoch 10 but jump to something like 37.7 in epoch 11, simply because the model failed to detect that one sample. This also affects any callbacks that rely on the loss, such as early stopping.
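The effect described above can be reproduced on a toy batch. The sketch below assumes a batch of 100 examples with a single positive, 'balanced' weights computed as n / (2 * n_c), and a loss that is the mean of per-example weighted binary cross-entropy; the predictions (0.05 for every negative, 0.9 vs. 0.01 for the positive) are made up for illustration, and real frameworks may normalize weighted losses differently. It compares how much the loss moves when that one positive goes from detected to missed, with and without class weights.

```python
import math


def weighted_bce(y_true, p_pred, class_weight, eps=1e-7):
    """Mean binary cross-entropy with each example scaled by its class weight."""
    total = 0.0
    for y, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1.0 - eps)          # clip to avoid log(0)
        loss = -math.log(p) if y == 1 else -math.log(1.0 - p)
        total += class_weight[y] * loss
    return total / len(y_true)


# Hypothetical batch with 1% positives: 99 negatives, 1 positive.
y = [0] * 99 + [1]
n, n_pos = len(y), sum(y)
balanced = {0: n / (2 * (n - n_pos)), 1: n / (2 * n_pos)}  # 'balanced' heuristic
uniform = {0: 1.0, 1: 1.0}

neg_preds = [0.05] * 99      # negatives handled well in both "epochs"
hit = neg_preds + [0.9]      # the positive is detected
miss = neg_preds + [0.01]    # the same positive is missed

for name, w in [("uniform", uniform), ("balanced", balanced)]:
    before, after = weighted_bce(y, hit, w), weighted_bce(y, miss, w)
    print(f"{name:8s} hit={before:.3f} miss={after:.3f} ratio={after / before:.1f}")
```

Missing the single positive roughly doubles the unweighted loss (ratio about 1.9) but multiplies the balanced-weight loss by about 30, because that one example now carries a weight of 50.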

In summary: if your situation could be like the one I described, don't use class weights; otherwise, experiment and find out what works best for you.

Licensed under: CC-BY-SA with attribution