Question

I am training a Siamese neural network with pairs of similar and dissimilar objects. The features are binary, indicating whether each object has certain properties or not (2048 features per object).

I then split my dataset into training, validation and test sets (60:20:20). Afterwards, I prepared the pair datasets myself by pairing up objects at random so that 50% of the pairs are similar and 50% dissimilar, and I augment the training set by generating extra random pairs (resulting in 100,000 instances, again balanced 50:50, vs. 1,000 instances for the validation set). I then train the Siamese network and estimate the cosine distance between the two outputs to get a similarity metric, which is compared against the label with the binary cross-entropy loss. The learning rate is low (lr = 0.0001) and I am using the Adam optimiser. I have tried using very small batches (batch_size = 25), adding dropout and increasing the number of instances to avoid overfitting, but the model still does not seem to generalise well (see picture). I was wondering if anyone could give me a hint on what is going on, and also why such bumps appear during the learning process.

[Image: training and validation learning curves]
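(For concreteness, a minimal sketch of this kind of setup in PyTorch: cosine similarity between the two embeddings fed into binary cross-entropy, Adam at lr = 0.0001. The encoder layer sizes, dropout rate and training-step helper below are illustrative placeholders, not the actual architecture from the question.)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Encoder(nn.Module):
        # Illustrative shared encoder: layer sizes and dropout rate are placeholders.
        def __init__(self, in_features=2048, hidden=256, out_features=64, p_drop=0.3):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(in_features, hidden), nn.ReLU(), nn.Dropout(p_drop),
                nn.Linear(hidden, out_features),
            )

        def forward(self, x):
            return self.net(x)

    encoder = Encoder()
    optimizer = torch.optim.Adam(encoder.parameters(), lr=0.0001)
    bce = nn.BCELoss()

    def train_step(x1, x2, label):
        # x1, x2: (batch, 2048) float tensors of binary features; label: 1 = similar, 0 = dissimilar.
        z1, z2 = encoder(x1), encoder(x2)
        # Cosine similarity lies in [-1, 1]; rescale to [0, 1] so it can be fed to BCE.
        sim = (F.cosine_similarity(z1, z2) + 1.0) / 2.0
        loss = bce(sim, label.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()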


Solution

Without going into the architecture (which looks reasonable, though it can always be improved): the weights are dominantly updated for the negative pairs. Why?

In the training set, where you have (approximately) 10,000 times more negative samples, you are training almost exclusively on negative pairs, so the network learns little else; when you then evaluate on the validation set you get exactly what you would expect: no improvement in accuracy or loss. Make sure the sets are better balanced (downsample the training pairs or upsample the validation/test pairs).
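(A rough sketch of what downsampling could look like; pos_pairs and neg_pairs are hypothetical lists of (x1, x2, label) tuples, not variables from the question.)

    import random

    def balance_pairs(pos_pairs, neg_pairs, seed=42):
        # Keep as many pairs of each class as the smaller class has, then shuffle.
        rng = random.Random(seed)
        n = min(len(pos_pairs), len(neg_pairs))
        balanced = rng.sample(pos_pairs, n) + rng.sample(neg_pairs, n)
        rng.shuffle(balanced)
        return balanced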

EDIT: given the new info, here are a couple of suggestions.

The indication is obviously overfitting.

  1. Early stopping: as soon as the validation loss reaches its minimum, STOP training. Anything after that is overkill (see the sketch after this list).

  2. Covariate shift: make sure the patterns are similar in all three sets (I assume they are, but just in case).

  3. Reduce complexity: how deep is your architecture? Maybe you are overdoing it. If you took some random Siamese network architecture from online, it may be too much (for example, architectures designed for images, where the extra complexity is actually needed).
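For point 1, a minimal patience-based early-stopping loop might look like the sketch below; train_one_epoch and evaluate are hypothetical helpers (one runs a training epoch, the other returns the validation loss), and the model is assumed to be a PyTorch module.

    import copy

    def fit(model, train_one_epoch, evaluate, max_epochs=200, patience=10):
        # Stop once the validation loss has not improved for `patience` epochs,
        # then restore the best weights seen so far.
        best_val, best_state, epochs_without_improvement = float("inf"), None, 0
        for epoch in range(max_epochs):
            train_one_epoch(model)
            val_loss = evaluate(model)
            if val_loss < best_val:
                best_val = val_loss
                best_state = copy.deepcopy(model.state_dict())
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    break
        if best_state is not None:
            model.load_state_dict(best_state)
        return best_val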

Licensed under: CC-BY-SA with attribution
Not affiliated with datascience.stackexchange