Question

After reading the 'Attention Is All You Need' paper, I understand the general architecture of a Transformer. However, it is unclear to me how the feed-forward neural network learns.

What I learned about neural nets is that they learn from a target variable, through backpropagation according to a particular loss function.

[Figure: a feed-forward neural network]

Looking at the architecture of a Transformer, it is unclear to me what the target variables are for these feed-forward nets. Can someone explain this to me?

[Figure: the Transformer architecture]


Solution

Let's take the common translation task that Transformers are used for as an example: if you want to translate English to German, one example from your training data could be

("the cat is black", "die Katze ist schwarz").

In this case your target is simply the German sentence "die Katze ist schwarz" (which is of course not processed as a raw string but via embeddings, including positional information). This is what you calculate your loss on, run backprop on, and derive the gradients and weight updates from.
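To make this concrete, here is a minimal sketch of one such training step in PyTorch. The token IDs, vocabulary size, and model dimensions are made up for illustration, and positional encodings and the causal decoder mask are omitted for brevity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy token IDs, for illustration only -- a real pipeline
# would use a learned tokenizer/vocabulary.
src_ids = torch.tensor([[5, 9, 2, 17]])    # "the cat is black"
tgt_ids = torch.tensor([[7, 12, 3, 25]])   # "die Katze ist schwarz"

vocab_size, d_model = 100, 32
embed = nn.Embedding(vocab_size, d_model)
model = nn.Transformer(d_model=d_model, nhead=4, batch_first=True)
to_logits = nn.Linear(d_model, vocab_size)

# Teacher forcing: the decoder input is the target shifted right
# (0 stands in for a BOS token here).
dec_in = torch.cat([torch.zeros_like(tgt_ids[:, :1]), tgt_ids[:, :-1]], dim=1)

out = model(embed(src_ids), embed(dec_in))          # (1, 4, d_model)
logits = to_logits(out)                             # (1, 4, vocab_size)

# The loss compares the predicted distribution at each position with the
# target token IDs; backprop then fills .grad for *every* parameter,
# including the feed-forward layers inside the Transformer blocks.
loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_size), tgt_ids.reshape(-1))
loss.backward()
```

Note that there is only one loss for the whole model; every layer, including the feed-forward blocks, is trained from it.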

Accordingly, you can think of the light blue feed-forward layers of a Transformer as hidden layers in a regular feed-forward network. Just as for a regular hidden layer, their parameters are updated by running backprop based on the Transformer's $loss(output, target)$, with the target being the translated sentence.
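The paper defines this block as $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, with $d_{\text{model}} = 512$ and $d_{\text{ff}} = 2048$. A minimal PyTorch sketch (the class name is mine) shows that it is an ordinary two-layer MLP whose gradients arrive purely via the chain rule from that single loss:

```python
import torch
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    """The light blue block: a plain two-layer MLP applied independently
    at every sequence position (paper sizes: d_model=512, d_ff=2048)."""
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):                      # x: (batch, seq_len, d_model)
        return self.linear2(torch.relu(self.linear1(x)))

ffn = PositionwiseFeedForward()
x = torch.randn(1, 4, 512)                     # e.g. a 4-token sentence
y = ffn(x)                                     # same shape as x

# No separate target exists for this block: its weights receive gradients
# only through the chain rule from the single sequence-level loss.
y.sum().backward()                             # toy scalar stand-in for the loss
print(ffn.linear1.weight.grad.shape)           # torch.Size([2048, 512])
```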

Licensed under: CC-BY-SA with attribution