Question

J'essaie de comprendre le code de la régression logistique sur la documentation officielle, mais je suis mal à comprendre la logique derrière ce code:

# early-stopping parameters
patience = 5000  # look as this many examples regardless
patience_increase = 2     # wait this much longer when a new best is
                              # found
improvement_threshold = 0.995  # a relative improvement of this much is
                               # considered significant
validation_frequency = min(n_train_batches, patience/2)
                              # go through this many
                              # minibatches before checking the network
                              # on the validation set; in this case we
                              # check every epoch

best_params = None
best_validation_loss = numpy.inf
test_score = 0.
start_time = time.clock()

done_looping = False
epoch = 0
while (epoch < n_epochs) and (not done_looping):
    # Report "1" for first epoch, "n_epochs" for last epoch
    epoch = epoch + 1
    for minibatch_index in xrange(n_train_batches):

        d_loss_wrt_params = ... # compute gradient
        params -= learning_rate * d_loss_wrt_params # gradient descent

        # iteration number. We want it to start at 0.
        iter = (epoch - 1) * n_train_batches + minibatch_index
        # note that if we do `iter % validation_frequency` it will be
        # true for iter = 0 which we do not want. We want it true for
        # iter = validation_frequency - 1.
        if (iter + 1) % validation_frequency == 0:

            this_validation_loss = ... # compute zero-one loss on validation set

            if this_validation_loss < best_validation_loss:

                # improve patience if loss improvement is good enough
                if this_validation_loss < best_validation_loss * improvement_threshold:

                    patience = max(patience, iter * patience_increase)
                best_params = copy.deepcopy(params)
                best_validation_loss = this_validation_loss

        if patience <= iter:
            done_looping = True
            break

Quelqu'un pourrait-il, expliquer à moi, qu'est-ce que les variables: la patience, patience_increase, improvement_threshold, validation_frequency, iter, représentent

Qu'est-ce que cette condition?

if (iter + 1) % validation_frequency == 0:
Était-ce utile?

La solution

Patience est le nombre de lots de formation à faire avant d'arrêter. iter est le nombre de lots de formation que vous avez vu. Chaque itération, vous décidez si oui ou non votre validation est inférieur à votre meilleur précédent. L'amélioration est uniquement si une note est inférieure à improvement_threshold * validation_score.

Il semble que patience_increase est un multiplicateur. Chaque fois que vous avez un nouveau score le plus bas, vous le nombre total ou lots formation à iter*patience_increase, mais pas en dessous de la valeur actuelle de patience.

validation_frequency est juste le nombre de lots entre les temps de vérifier le score de validation.

Licencié sous: CC-BY-SA avec attribution
scroll top