ロジスティック回帰の早期停止。シーノ
-
16-10-2019 - |
質問
公式ドキュメントのロジスティック回帰のコードを理解しようとしていますが、このコードの背後にあるロジックを理解するのに苦労しています。
# early-stopping parameters
patience = 5000 # look as this many examples regardless
patience_increase = 2 # wait this much longer when a new best is
# found
improvement_threshold = 0.995 # a relative improvement of this much is
# considered significant
validation_frequency = min(n_train_batches, patience/2)
# go through this many
# minibatches before checking the network
# on the validation set; in this case we
# check every epoch
best_params = None
best_validation_loss = numpy.inf
test_score = 0.
start_time = time.clock()
done_looping = False
epoch = 0
while (epoch < n_epochs) and (not done_looping):
# Report "1" for first epoch, "n_epochs" for last epoch
epoch = epoch + 1
for minibatch_index in xrange(n_train_batches):
d_loss_wrt_params = ... # compute gradient
params -= learning_rate * d_loss_wrt_params # gradient descent
# iteration number. We want it to start at 0.
iter = (epoch - 1) * n_train_batches + minibatch_index
# note that if we do `iter % validation_frequency` it will be
# true for iter = 0 which we do not want. We want it true for
# iter = validation_frequency - 1.
if (iter + 1) % validation_frequency == 0:
this_validation_loss = ... # compute zero-one loss on validation set
if this_validation_loss < best_validation_loss:
# improve patience if loss improvement is good enough
if this_validation_loss < best_validation_loss * improvement_threshold:
patience = max(patience, iter * patience_increase)
best_params = copy.deepcopy(params)
best_validation_loss = this_validation_loss
if patience <= iter:
done_looping = True
break
誰もが私に説明できますか、変数は何を説明しますか:忍耐、忍耐_increase、改善_threshold、validation_frequency、iter、assence?
この状態は何をしますか?
if (iter + 1) % validation_frequency == 0:
解決
Patience
停止する前に、トレーニングバッチの数です。 iter
見たトレーニングバッチの数です。各反復では、検証が以前のベストよりも低いかどうかを決定します。改善は、スコアがより低い場合にのみ保存されます improvement_threshold * validation_score
.
それはそうです patience_increase
乗数です。新しいスコアが新しいスコアがあるたびに、総数またはトレーニングバッチを上げます iter*patience_increase
, 、しかし、現在の値を下回っていません patience
.
validation_frequency
検証スコアをチェックする時間間のバッチの数です。
所属していません datascience.stackexchange