Question

I'm trying to implement an RBM and I'm testing it on the MNIST dataset. However, it does not seem to converge.

I have 28x28 visible units and 100 hidden units. I'm using mini-batches of size 50. For each epoch, I traverse the whole dataset. The learning rate is 0.01 and the momentum is 0.5. The weights are randomly drawn from a Gaussian distribution with mean 0.0 and standard deviation 0.01. The visible and hidden biases are initialized to 0. I'm using the logistic sigmoid as the activation function.
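For reference, here is a minimal sketch of the setup described above (the names rbm_parameters, init_rbm and sigmoid are illustrative and not taken from the linked code):

#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Sizes from the question: 28x28 visible units, 100 hidden units.
constexpr std::size_t num_visible = 28 * 28;
constexpr std::size_t num_hidden  = 100;

struct rbm_parameters {
    std::vector<double> weights;    // num_visible * num_hidden, row-major
    std::vector<double> visible_b;  // one bias per visible unit
    std::vector<double> hidden_b;   // one bias per hidden unit
};

// Weights drawn from N(0, 0.01), biases initialized to zero.
rbm_parameters init_rbm(std::mt19937& rng) {
    std::normal_distribution<double> gauss(0.0, 0.01);
    rbm_parameters p;
    p.weights.resize(num_visible * num_hidden);
    for (auto& w : p.weights) {
        w = gauss(rng);
    }
    p.visible_b.assign(num_visible, 0.0);
    p.hidden_b.assign(num_hidden, 0.0);
    return p;
}

// Logistic sigmoid activation.
double sigmoid(double x) {
    return 1.0 / (1.0 + std::exp(-x));
}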

After each epoch, I compute the average reconstruction error over all mini-batches (one possible definition of that error is sketched after the numbers below); here are the errors I get:

epoch 0: Reconstruction error average: 0.0481795
epoch 1: Reconstruction error average: 0.0350295
epoch 2: Reconstruction error average: 0.0324191
epoch 3: Reconstruction error average: 0.0309714
epoch 4: Reconstruction error average: 0.0300068
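For context, a common definition of this error, and the one assumed in the sketch below since the question does not state the exact measure used, is the mean squared difference between a mini-batch and its one-step (CD-1) reconstruction:

#include <cstddef>
#include <vector>

// Mean squared difference between a mini-batch and its reconstruction.
// This is only one possible definition of the reconstruction error.
double reconstruction_error(const std::vector<std::vector<double>>& batch,
                            const std::vector<std::vector<double>>& reconstruction) {
    double error = 0.0;
    std::size_t count = 0;
    for (std::size_t s = 0; s < batch.size(); ++s) {
        for (std::size_t i = 0; i < batch[s].size(); ++i) {
            const double diff = batch[s][i] - reconstruction[s][i];
            error += diff * diff;
            ++count;
        }
    }
    return count > 0 ? error / static_cast<double>(count) : 0.0;
}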

I plotted histograms of the parameters to check (left to right: hidden biases, weights, visible biases; top: the values, bottom: their updates):

Histogram of the weights after epoch 3: http://baptiste-wicht.com/static/finals/histogram_epoch_3.png

Histogram of the weights after epoch 4: http://baptiste-wicht.com/static/finals/histogram_epoch_4.png

but, except for the hidden biases that seem a bit weird, the remaining seems OK.

I also tried to plot the hidden weights:

Weights after epoch 3: http://baptiste-wicht.com/static/finals/hiddens_weights_epoch_3.png

Weights after epoch 4: http://baptiste-wicht.com/static/finals/hiddens_weights_epoch_4.png

(they are plotted in two colors using this expression, which maps positive values to one color channel and negative values to another:

static_cast<size_t>(value > 0 ? (static_cast<size_t>(value * 255.0) << 8) : (static_cast<size_t>(-value * 255.0) << 16)) << " ";

)

And here, they do not make sense at all...

If I train further, the reconstruction error falls a bit more, but does not go below 0.025. Even if I change the momentum after some time, the error goes up and then comes down a bit, but not significantly. Moreover, the weights do not make any more sense after more epochs. In most example implementations I've seen, the weights started to make some sense after two or three passes over the complete dataset.

I've also tried to reconstruct an image from the visible units, but the results seem almost random.

What could I do to check what is going wrong in my implementation? Should the weights be within some range? Does anything in the data seem really strange?

Complete code: https://github.com/wichtounet/dbn/blob/master/include/rbm.hpp


Solution

You are using a very small learning rate. In most neural networks trained with SGD, you start out with a higher learning rate and decay it over time. Search for "learning rate" or "adaptive learning rate" to find more information on that.
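For illustration only, here is a sketch of a simple exponential decay schedule; the starting rate of 0.1 and decay factor of 0.95 in the usage comment are placeholder assumptions, not values recommended by this answer:

#include <cmath>
#include <cstddef>

// rate(epoch) = initial_rate * decay^epoch
// e.g. decayed_learning_rate(0.1, 0.95, 10) is about 0.0599.
double decayed_learning_rate(double initial_rate, double decay, std::size_t epoch) {
    return initial_rate * std::pow(decay, static_cast<double>(epoch));
}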

Second, when implementing a new algorithm, I would recommend finding the paper that introduced it and reproducing its results. A good paper should include most of the settings used, or the method used to determine them.

If the paper is unavailable, or the method was tested on a dataset you don't have access to, find a working implementation and compare the outputs when using the same settings. If the implementations are not feature-compatible, turn off as many of the non-shared features as you can.

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow