This repository contains the PyTorch implementation of our paper "LaMemo: Language Modeling with Look-Ahead Memory" (NAACL 2022).
The look-ahead memory allows the model to update the memory representations with the context on its right and interpolate them with the old memory for bi-directionality. As shown in the figure below, this mechanism increases the model's utilization of distant contexts compared to the recurrence memory in Transformer-XL.
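The interpolation step can be sketched as a per-position convex combination of the stale memory and the representation refreshed with right-side context. The function and argument names below are illustrative, not the repository's API; the exact gating used by LaMemo is defined in the paper.

```python
def interpolate_memory(old_mem, refreshed_mem, alpha):
    """Blend look-ahead-refreshed memory into the old recurrence memory.

    alpha[i] near 1.0 trusts the refreshed (look-ahead) representation,
    near 0.0 keeps the old memory. All names here are hypothetical.
    """
    return [a * r + (1.0 - a) * o
            for a, r, o in zip(alpha, refreshed_mem, old_mem)]


# Example: position 0 blends halfway, position 1 keeps mostly the old memory.
interpolate_memory([1.0, 2.0], [3.0, 4.0], alpha=[0.5, 0.25])  # -> [2.0, 2.5]
```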
You need PyTorch 1.10.0 or above to run the code, as well as Apex for mixed-precision training. The script uses distributed data parallel across multiple GPUs with mixed precision at the O2 optimization level for better efficiency.
First download the data by running getdata.sh. The data will be downloaded and processed into the data folder. For text8, run prep_text8.py for preprocessing.
Run the scripts in the pytorch folder with different model configurations by specifying the --attn_type argument:

- attn_type = 0: use the rpe of Transformer-XL & recurrence memory
- attn_type = 1: use the rpe of Shaw et al. (2018) & recurrence memory
- attn_type = 2: use absolute position encodings (mem_len set to 0)
- attn_type = 3: use disentangled rpe & look-ahead memory (default configuration of LaMemo)
- attn_type = 31: use the rpe of Transformer-XL & look-ahead memory
- attn_type = 32: use disentangled rpe ver2 (learns two vectors $u_+$ and $u_-$ in Eq. (12)) & look-ahead memory
- attn_type = 33: use disentangled rpe & look-ahead memory with full look-ahead attention
To train LaMemo base on wikitext-103, run the following command in the pytorch folder:

bash run_wt103_base_ddp.sh train 2

For text8 and enwik8, the training command is similar.
We provide pre-trained checkpoints of LaMemo base and Transformer-XL base on wikitext-103.
Create a new models directory and move the unzipped model files into it.
To evaluate LaMemo base on wikitext-103, run the following command in the pytorch folder:

bash run_wt103_base_ddp.sh eval 1 --work_dir /path/to/model/checkpoint

For text8 and enwik8, the evaluation command is similar.
We provide a script to generate text from models trained on wikitext-103.
You can provide the prompt manually by pointing the --manual argument to a prompt file.
The default decoding strategy uses top-p sampling with p=0.95 and temperature=1.0.
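Top-p (nucleus) sampling keeps the smallest set of tokens whose cumulative probability reaches p, renormalizes over that set, and samples from it; the temperature rescales the logits before the softmax. A minimal pure-Python sketch (not the repository's decoding code; function names are illustrative):

```python
import math
import random


def top_p_sample(logits, p=0.95, temperature=1.0, rng=random):
    """Sample a token index with nucleus (top-p) sampling."""
    # Temperature-scaled softmax.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Sort token indices by probability, descending.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Keep the smallest prefix whose cumulative mass reaches p.
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= p:
            break
    # Renormalize over the nucleus and draw a sample.
    total = sum(probs[i] for i in kept)
    r = rng.random() * total
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]
```

With a sharply peaked distribution the nucleus collapses to the argmax; with a flatter one (or p close to 1) more of the tail stays eligible, which is why p=0.95 gives diverse but still truncated sampling.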
bash run_wt103_base_ddp.sh gen 1 --manual ../data/wt103-prompt2.txt

This repository is implemented based on the original Transformer-XL repo and NVIDIA's implementation, which supports mixed-precision training.
@inproceedings{ji2022LaMemo,
  title = "LaMemo: Language Modeling with Look-Ahead Memory",
  author = "Ji, Haozhe and Zhang, Rongsheng and Yang, Zhenyu and Hu, Zhipeng and Huang, Minlie",
  booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics",
  year = "2022",
}
Please cite our paper if you find it and the code useful!
