SGD vs SGD in mini lotti
-
01-11-2019 - |
Domanda
Così ho recentemente finito un mini algoritmo di batch per una biblioteca in costruzione a Java (rete neurale artificiale Lib). Ho quindi seguito per formare la mia rete per un problema XOR in mini batch di 2 o 3, poiché entrambi ho ottenuto una precisione peggiore per quello che ho ottenuto dal realizzare 1 (che è fondamentalmente solo SGD). Ora capisco che ho bisogno di allenarlo su più epoche, ma non sto notando alcuna velocità in fase di esecuzione che da quello che leggo dovrebbe accadere. Perchè è questo?
Ecco il mio codice (Java)
public void SGD(double[][] inputs,double[][] expected_outputs,int mini_batch_size,int epochs, boolean verbose){
//Set verbose
setVerbose(verbose);
//Create training set
TrainingSet trainingSet = new TrainingSet(inputs,expected_outputs);
//Loop through Epochs
for(int i = 0; i<epochs;i++){
//Print Progress
print("\rTrained: " + i + "/" + epochs);
//Shuffle training set
trainingSet.shuffle();
//Create the mini batches
TrainingSet.Data[][] mini_batches = createMiniBatches(trainingSet,mini_batch_size);
//Loop through mini batches
for(int j = 0; j<mini_batches.length;j++){
update_mini_batch(mini_batches[j]);
}
}
//Print Progress
print("\rTrained: " + epochs + "/" + epochs);
print("\nDone!");
}
private Pair backprop(double[] inputs, double[] target_outputs){
//Create Expected output column matrix
Matrix EO = Matrix.fromArray(new double[][]{target_outputs});
//Forward Propagate inputs
feedForward(inputs);
//Get the Errors which is also the Bias Delta
Matrix[] Errors = calculateError(EO);
//Weight Delta Matrix
Matrix[] dCdW = new Matrix[Errors.length];
//Calculate the Deltas
//Calculating the first Layers Delta
dCdW[0] = Matrix.dot(Matrix.transpose(I),Errors[0]);
//Rest of network
for (int i = 1; i < Errors.length; i++) {
dCdW[i] = Matrix.dot(Matrix.transpose(H[i - 1]), Errors[i]);
}
return new Pair(dCdW,Errors);
}
private void update_mini_batch(TrainingSet.Data[] mini_batch){
//Get first deltas
Pair deltas = backprop(mini_batch[0].input,mini_batch[0].output);
//Loop through mini batch and sum the deltas
for(int i = 1; i< mini_batch.length;i++){
deltas.add(backprop(mini_batch[i].input,mini_batch[i].output));
}
//Multiply deltas by the learning rate
//and divide by the mini batch size to get
//the mean of the deltas
deltas.multiply(learningRate/mini_batch.length);
//Update Weights and Biases
for(int i= 0; i<W.length;i++){
W[i].subtract(deltas.dCdW[i]);
B[i].subtract(deltas.dCdB[i]);
}
}
Nessuna soluzione corretta
Autorizzato sotto: CC-BY-SA insieme a attribuzione
Non affiliato a datascience.stackexchange