Question

If I do not pretrain a text generation model the way BART is pretrained, how can I improve the results of a plain transformer (for example, tensor2tensor)?

What ideas are there for improving a transformer on text generation tasks?


Solution

If you have a lot of training data available, you should apply the techniques used in large transformer models like GPT-2: very deep models (48 layers for the 1.5B-parameter version), modified initialization, pre-normalization, and reversible tokenization. You could also apply the locally banded sparse attention patterns used in GPT-3.
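As a concrete illustration of the pre-normalization and modified-initialization points, here is a minimal PyTorch sketch (my own, not the GPT-2 or tensor2tensor source; the class name PreNormBlock and all sizes are made up):

```python
import math
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """One pre-norm decoder block: LayerNorm is applied *before* each sub-layer
    and the residual path stays an identity, which keeps very deep stacks stable."""

    def __init__(self, d_model, n_heads, n_layers, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.drop = nn.Dropout(dropout)
        # "Modified initialization": scale the weights of the layers that feed the
        # residual stream by 1/sqrt(number of residual layers), GPT-2 style.
        for proj in (self.attn.out_proj, self.mlp[2]):
            nn.init.normal_(proj.weight, mean=0.0, std=0.02 / math.sqrt(2 * n_layers))
            nn.init.zeros_(proj.bias)

    def forward(self, x, attn_mask=None):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + self.drop(a)                      # residual 1: attention
        x = x + self.drop(self.mlp(self.ln2(x)))  # residual 2: feed-forward
        return x
```

GPT-2 also adds one final LayerNorm after the last block, so a stack of these blocks should end with an extra `nn.LayerNorm(d_model)` on the output.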

If you have very little training data, you can apply the "unwritten" aggressive regularization techniques described in the tweet thread below, namely data augmentation, discrete embedding dropout, regular dropout plus weight decay, and lots of patient training time.
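Before the thread, here is a hedged PyTorch sketch of the two dropout tricks; DiscreteEmbeddingDropout is a name I made up, and the Transformer-XL fork linked in the thread implements the changes differently:

```python
import torch
import torch.nn as nn

class DiscreteEmbeddingDropout(nn.Module):
    """With probability p, zero a token's *entire* embedding vector, so the model
    must do regular language modeling with an uncertain context (not masked LM)."""

    def __init__(self, embedding, p=0.1):
        super().__init__()
        self.embedding = embedding
        self.p = p

    def forward(self, token_ids):
        emb = self.embedding(token_ids)                  # (batch, seq, d_model)
        if self.training and self.p > 0:
            # One Bernoulli draw per token, broadcast over the embedding dimension,
            # with inverted-dropout rescaling so the expected scale is unchanged.
            keep = (torch.rand(emb.shape[:-1], device=emb.device) >= self.p).float()
            emb = emb * keep.unsqueeze(-1) / (1.0 - self.p)
        return emb

# Sizes below are purely illustrative.
vocab_size, d_model = 10_000, 256
embed = DiscreteEmbeddingDropout(nn.Embedding(vocab_size, d_model), p=0.1)
input_dropout = nn.Dropout(p=0.1)   # regular element-wise dropout on the embeddings

tokens = torch.randint(0, vocab_size, (8, 128))   # (batch, seq)
x = input_dropout(embed(tokens))                  # feed x into the transformer stack
```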

Update: I feel the tweet thread I referred to is important, so here are the most relevant tweets (a small usage sketch follows after the thread):

  • How can you successfully train transformers on small datasets like PTB and WikiText-2? Are LSTMs better on small datasets? I ran 339 experiments worth 568 GPU hours and came up with some answers. I do not have time to write a blog post, so here is a Twitter thread instead.

  • To give a bit of background: all this came about from my past frustration with replicating Transformer-XL results on PTB and having very poor results on WikiText-2 (WT2). On WT2, my best model after 200+ experiments was 90-ish ppl, which is far from standard LSTM baselines (65.8 ppl).

  • ...

  • The key insight is the following: In the small dataset regime, it is all about dataset augmentation. The analog in computer vision is that you get much better results, particularly on small datasets, if you do certain dataset augmentations. This also regularizes the model.

  • The most dramatic performance gain comes from discrete embedding dropout: You embed as usual, but now with a probability p you zero the entire word vector. This is akin to masked language modeling but the goal is not to predict the mask — just regular LM with uncertain context.

  • The second most important factor is regular input dropout: You take the embeddings and drop out elements with probability p. This also has a data augmentation effect very similar to dropping out random pixels for images. What is a good way to think about this? 1/2

  • Remember that we can do King-man+woman=Queen? Now imagine input dropout removes the "man" component of "King". This forces the model to distribute specific information (gender in this case) into multiple dimensions to improve generalization, making it more robust. 2/2

  • Otherwise, it is a game of further regularization (more dropout + weight decay) and of patience. I can train a good model without these tricks in 15 minutes and get 97 ppl. If I apply all these dropouts the model underfits after 7h of training to 63.4 ppl (better than LSTM).

  • You can also apply these data augmentation recipes to large datasets, but nobody would like to train for months on WT-103 for a couple of ppl points. In my opinion, techniques that require so much extra compute are more harmful to the community than useful. 1/2

  • Here are the code changes to the public Transformer-XL repo that my results are based on: https://github.com/TimDettmers/transformer-xl/tree/wikitext2

  • With my changes to the public Transformer-XL repo, you can run this script to get down to 63.4 ppl on WT2: https://github.com/TimDettmers/transformer-xl/blob/wikitext2/pytorch/replicate_wt2.sh
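And here is the small usage sketch mentioned above, covering the remaining ingredients from the thread (the dropouts are already in the modules sketched earlier, so this only shows weight decay and the "patience" point). The actual replicate_wt2.sh script configures the Transformer-XL trainer itself, so treat the optimizer choice and numbers here as illustrative:

```python
import torch
import torch.nn as nn

# Toy stand-in for the language model; in practice this would be the Transformer-XL
# from the repo linked above, with the embedding/input dropout sketched earlier.
model = nn.Sequential(nn.Embedding(10_000, 256), nn.Linear(256, 10_000))

# Decoupled weight decay via AdamW; the value is a hyperparameter to tune.
optimizer = torch.optim.AdamW(model.parameters(), lr=2.5e-4, weight_decay=0.1)

# "Patience": with heavy regularization the model keeps improving for many hours,
# so schedule the learning rate over a much longer horizon than usual.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200_000)
```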

Licensed under: CC-BY-SA with attribution