質問

からの画像キャプションシステムのデモを実装しようとしています ケラスのドキュメント. 。ドキュメントから、トレーニングの部分を理解できました。

max_caption_len = 16
vocab_size = 10000

# first, let's define an image model that
# will encode pictures into 128-dimensional vectors.
# it should be initialized with pre-trained weights.
image_model = VGG-16 CNN definition
image_model.load_weights('weight_file.h5')

# next, let's define a RNN model that encodes sequences of words
# into sequences of 128-dimensional word vectors.
language_model = Sequential()
language_model.add(Embedding(vocab_size, 256, input_length=max_caption_len))
language_model.add(GRU(output_dim=128, return_sequences=True))
language_model.add(TimeDistributedDense(128))

# let's repeat the image vector to turn it into a sequence.
image_model.add(RepeatVector(max_caption_len))

# the output of both models will be tensors of shape (samples, max_caption_len, 128).
# let's concatenate these 2 vector sequences.
model = Merge([image_model, language_model], mode='concat', concat_axis=-1)
# let's encode this vector sequence into a single vector
model.add(GRU(256, 256, return_sequences=False))
# which will be used to compute a probability
# distribution over what the next word in the caption should be!
model.add(Dense(vocab_size))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

model.fit([images, partial_captions], next_words, batch_size=16, nb_epoch=100)

しかし今、私はテスト画像のキャプションを生成する方法に混乱しています。ここに入力は[画像、partial_caption]ペアです。テスト画像では、部分的なキャプションを入力する方法は?

役に立ちましたか?

解決

この例は、キャプションの次の単語を予測するために画像と部分的なキャプションを訓練します。

Input: [🐱, "<BEGIN> The cat sat on the"]
Output: "mat"

モデルは、次の単語のみのキャプションの出力全体を予測していないことに注意してください。新しいキャプションを作成するには、各単語について複数回予測する必要があります。

Input: [🐱, "<BEGIN>"] # predict "The"
Input: [🐱, "<BEGIN> The"] # predict "cat"
Input: [🐱, "<BEGIN> The cat"] # predict "sat"
...

シーケンス全体を予測するには、使用する必要があると思います TimeDistributedDense 出力層用。

Input: [🐱, "<BEGIN> The cat sat on the mat"]
Output: "The cat sat on the mat <END>"

この問題を参照してください: https://github.com/fchollet/keras/issues/1029

ライセンス: CC-BY-SA帰属
所属していません datascience.stackexchange
scroll top