Transformer decoder output - how is it linear?
11-12-2020
Problem
I'm not quite sure how the decoder output is flattened into a single vector. As I understand it, if we feed the encoder a sentence of length N, its output is N x units (e.g. N x 1000), and if we feed the decoder a sentence of length M, the decoder output is M x units. M is not fixed: it is the length of the decoder's raw input, and it changes from step to step during inference. How do we get from there to a single vector?
[Screenshot from "Attention Is All You Need" omitted]
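The shapes described in the question can be checked with a minimal sketch of one encoder-decoder attention step. It assumes a single head with no learned projections and random toy data (the sizes are hypothetical; the paper uses d_model = 512). The point is that the decoder queries (M rows) attend over the encoder output (N rows), and the result has one row per decoder position, so its shape is M x d_model whatever N is.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V, with no masking and a single head."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                # (M, N): one row per decoder position
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row sums to 1 over the N source positions
    return weights @ v                              # (M, d_v)

rng = np.random.default_rng(0)
d_model = 8                                  # toy size (hypothetical)
N = 5                                        # source sentence length
encoder_out = rng.standard_normal((N, d_model))

for M in (1, 3, 7):                          # the decoder input grows during inference
    decoder_states = rng.standard_normal((M, d_model))
    out = scaled_dot_product_attention(decoder_states, encoder_out, encoder_out)
    print(M, out.shape)                      # the number of rows follows M, not N
```

The learned query/key/value projections and multiple heads are left out on purpose; they change the feature dimension inside the layer but not the fact that the output keeps M rows.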
Solution
"I'm not quite sure how the decoder output is flattened into a single vector."
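That's the thing: it isn't flattened into a single vector. The linear transformation is applied to each of the $M$ vectors in the sequence individually, with the same weights at every position. Each of those vectors has a fixed dimension ($d_{model}$, 512 in the paper), so the weight matrix has a fixed shape ($d_{model} \times$ vocabulary size) no matter how large $M$ is, which is why it works. The result is $M$ vectors of vocabulary logits, and the softmax turns each one into a probability distribution over the vocabulary.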
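A minimal NumPy sketch of this position-wise projection, assuming toy sizes and random weights (`d_model`, `vocab_size`, and `output_projection` are hypothetical names, not from the paper). It checks that multiplying the whole M x d_model matrix by one weight matrix gives the same result as projecting each row on its own, for several values of M. Reading only the last row to pick the next token is an assumption about greedy autoregressive decoding; beam search would use the same last-row distribution differently.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab_size = 8, 20           # toy sizes (hypothetical)

# One weight matrix and bias shared by every position; their shapes depend only on d_model.
W = rng.standard_normal((d_model, vocab_size))
b = np.zeros(vocab_size)

def output_projection(decoder_out):
    """Map (M, d_model) decoder states to (M, vocab_size) probabilities, row by row."""
    logits = decoder_out @ W + b                          # same W applied to each of the M rows
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)      # softmax over the vocabulary, per row

for M in (1, 4, 9):
    decoder_out = rng.standard_normal((M, d_model))
    probs = output_projection(decoder_out)
    # Projecting the whole matrix equals projecting each row separately.
    per_row = np.stack([output_projection(row[None, :])[0] for row in decoder_out])
    assert np.allclose(probs, per_row)
    next_token = int(probs[-1].argmax())  # greedy decoding (assumed) reads only the last position
    print(M, probs.shape, next_token)
```

Nothing in `W` or `b` refers to M, which is why the same layer handles a decoder output of any length.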