This repository includes the code used to train models, run benchmarks, and create plots for the paper *Don't Pay Attention*.
> [!NOTE]
> Licenses in subdirectories take precedence over the repository license for their respective subdirectories, and licenses in individual files take precedence over the subdirectory license for their respective files.
All the model checkpoints trained during this work are available on HuggingFace. Base and instruct-tuned preview versions of Avey-1, which you can try out, are available here.
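For example, a preview checkpoint can be tried with a few lines of `transformers` code. This is a minimal sketch, assuming the checkpoints ship custom modeling code loadable via `trust_remote_code`; the repo id is only an example taken from the list below:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example repo id; pick an actual preview checkpoint from the HuggingFace page.
model_id = "avey-ai/avey1-dpa-1.5B-100BT"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Don't pay attention to", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```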
This code was tested on Ubuntu 22.04 with Python 3.12 on A100, H100, and H200 GPUs. It is recommended to run the setup and code in a clean Python environment. Please make sure the CUDA toolkit is installed correctly.
To clone the repo and install the dependencies, run:

```sh
git clone https://github.avey.ai/avey-dpa
cd avey-dpa
source setup.sh
```

Either log in to Weights & Biases if you want to log training metrics:

```sh
wandb login
```

or disable it:

```sh
wandb disabled
```

Set the model name and path for whichever of the following models you want to train/test:
> [!WARNING]
> Running training and benchmarks will download the sample-10BT split of the FineWeb dataset as well as model checkpoints. Make sure you have enough available disk space (at least ~120 GB).
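If you want to verify free space before kicking off a run, here is a quick check using Python's standard library:

```python
import shutil

# Free space on the current directory's filesystem, in GB.
free_gb = shutil.disk_usage(".").free / 1e9
print(f"{free_gb:.0f} GB free")  # at least ~120 GB is needed
```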
The names of the model checkpoints are in the format `<model name>-dpa-<model size>-<checkpoint>`, where:

- Model name can be one of `avey`, `mamba`, `rwkv7`, or `tpp`
- Model size can be one of `0.1B`, `0.5B`, or `1.5B`
- Available checkpoints are `100BT`, `95BT`, or `90BT`
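As a worked example of this naming convention, a full repo id is composed as follows (illustrative Python; the concrete `export` commands for each model are given below):

```python
# Compose a checkpoint repo id from the naming convention above.
model_name, model_size, checkpoint = "mamba", "1.5B", "100BT"
repo_id = f"avey-ai/{model_name}-dpa-{model_size}-{checkpoint}"
print(repo_id)  # avey-ai/mamba-dpa-1.5B-100BT
```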
All checkpoints are available under the avey-ai organization on HuggingFace, here.
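Given the disk-space warning above, you may want to pre-fetch a checkpoint ahead of time. A minimal sketch using the `huggingface_hub` library (the repo id is the example composed above):

```python
from huggingface_hub import snapshot_download

# Downloads every file of the checkpoint into the local HuggingFace cache.
snapshot_download(repo_id="avey-ai/mamba-dpa-1.5B-100BT")
```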
For avey:

```sh
export MODEL_NAME=avey
export MODEL_PATH=avey-ai/avey1-dpa-1.5B-100BT
```

For mamba, install its dependencies:

```sh
sh mamba/setup.sh
```

then set the model name and path:

```sh
export MODEL_NAME=mamba
export MODEL_PATH=avey-ai/mamba-dpa-1.5B-100BT
```

For rwkv7:

```sh
source rwkv7/env.sh
export MODEL_NAME=rwkv7
export MODEL_PATH=avey-ai/rwkv7-dpa-1.5B-100BT
```

For tpp:

```sh
export MODEL_NAME=tpp
export MODEL_PATH=avey-ai/tpp-dpa-1.5B-100BT
```

Adjust `NUMBER_OF_GPUS` (on a single node) and `BATCH_SIZE` (start at 1 and increase until your GPU runs out of memory) in `train.sh`, and then run:
```sh
sh train.sh
```

For standard benchmarks reported in the paper, run:

```sh
sh eval.sh
```

For RULER S-NIAH, run:

```sh
sh eval-long.sh
```

To plot the NIAH heatmap (Figure 1 in the paper), run:

```sh
sh plot-niah.sh
```

To plot TTFT vs. context length (Figure 4 in the paper), run:
> [!IMPORTANT]
> Make sure you've already run the setup steps for mamba (`sh mamba/setup.sh`) and rwkv7 (`source rwkv7/env.sh`).
```sh
python3 plot_ttft.py
```
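For context, TTFT (time to first token) is the latency from submitting a prompt until the first generated token is produced, i.e. the prefill cost plus one decoding step. A minimal sketch of how it can be measured, assuming a model loaded through a transformers-style API as in the earlier example (this is illustrative, not the actual implementation of `plot_ttft.py`):

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example checkpoint; any of the checkpoints above would do.
model_id = "avey-ai/avey1-dpa-1.5B-100BT"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).cuda().eval()

prompt = "..."  # a prompt of the desired context length
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=1)  # TTFT = prefill + first decode step
torch.cuda.synchronize()
print(f"TTFT: {time.perf_counter() - start:.3f} s")
```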