Python 3 dependencies:
- PyTorch (tested with 2.3.1 and 1.11.0; other versions likely also work)
- NumPy
- argparse
- Matplotlib
We consider the in-context linear regression task (Garg et al., 2022; Zhang et al., 2024). The input is a sequence of $N$ example pairs followed by a query token,

$$Z = \begin{pmatrix} x_1 & x_2 & \cdots & x_N & x_{N+1} \\ y_1 & y_2 & \cdots & y_N & 0 \end{pmatrix},$$

where $y_i = w^\top x_i$, and the model is trained to predict the held-out label $y_{N+1} = w^\top x_{N+1}$ of the query token.
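As a sketch of this setup (batch size, dimensions, and Gaussian sampling here are illustrative choices, not necessarily the repo's exact defaults), one batch of training sequences could be generated as:

```python
import numpy as np

def make_batch(batch=4, d=8, N=32, rng=np.random.default_rng(0)):
    """Build in-context linear regression sequences.

    Each sequence stacks N (x_i, y_i) example pairs plus one query
    token x_{N+1} whose label slot is zeroed out.
    """
    w = rng.standard_normal((batch, d))              # one task vector per sequence
    x = rng.standard_normal((batch, N + 1, d))       # N examples + 1 query
    y = np.einsum("bnd,bd->bn", x, w)                # y_i = w^T x_i
    Z = np.concatenate([x, y[:, :, None]], axis=-1)  # tokens z_i = (x_i, y_i)
    Z[:, -1, -1] = 0.0                               # hide the query label
    target = y[:, -1]                                # model should predict w^T x_{N+1}
    return Z, target

Z, target = make_batch()
print(Z.shape, target.shape)  # (4, 33, 9) (4,)
```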
Multi-head linear attention with the key and query matrices merged as a single matrix $W^{KQ}_h = (W^K_h)^\top W^Q_h$, so that the update to the query token is

$$\Delta z_{N+1} = \frac{1}{N} \sum_{h=1}^{H} W^{V}_h\, Z Z^\top\, W^{KQ}_h\, z_{N+1}.$$
Simulate a loss trajectory of 8-head merged linear attention:

```
python train.py --model attnM --head 8 --init 1e-6 --show
```

When trained on in-context linear regression tasks, linear attention with merged key and query is equivalent to a 2-layer fully-connected linear network that takes a set of cubic features as input,
where each cubic feature is a third-order product of token entries, collected (up to reshaping) in $\operatorname{vec}\big(\tfrac{1}{N} Z Z^\top\big) \otimes z_{N+1}$.
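To see why the features are cubic, note that the merged-head output above is a product of three token-linear factors. A minimal single-head numpy sketch (random matrices and small dimensions assumed for illustration) confirms that scaling all tokens by 3 scales the output by 27:

```python
import numpy as np

rng = np.random.default_rng(1)
d, N = 4, 16

# Random single-head parameters: merged key-query and value matrices.
W_KQ = rng.standard_normal((d + 1, d + 1))
W_V = rng.standard_normal((d + 1, d + 1))

def merged_head(Z, z_query):
    """One merged-KQ linear attention head acting on the query token."""
    return W_V @ (Z @ Z.T / N) @ W_KQ @ z_query

Z = rng.standard_normal((d + 1, N))
z_q = rng.standard_normal(d + 1)

out1 = merged_head(Z, z_q)
out3 = merged_head(3 * Z, 3 * z_q)
print(np.allclose(out3, 27 * out1))  # True: the map is cubic in the tokens
```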
If we train this equivalent MLP on the cubic features:

```
python train.py --model mlp --cubic_feat --head 8 --init 1e-6 --show
```

Multi-head linear attention with separate key and query, where each head has its own $W^K_h$ and $W^Q_h$, is defined as

$$\Delta z_{N+1} = \frac{1}{N} \sum_{h=1}^{H} W^{V}_h\, Z\, (W^K_h Z)^\top\, W^Q_h\, z_{N+1}.$$
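A hedged single-head sketch of the separate key/query parameterization (shapes and random weights assumed for illustration), which also checks that merging the two matrices into $W^{KQ} = (W^K)^\top W^Q$ reproduces the same output while constraining its rank:

```python
import numpy as np

rng = np.random.default_rng(2)
d, N, r = 4, 16, 1  # r = rank of the key/query weights (rank-one here)

W_K = rng.standard_normal((r, d + 1))
W_Q = rng.standard_normal((r, d + 1))
W_V = rng.standard_normal((d + 1, d + 1))

Z = rng.standard_normal((d + 1, N))
z_q = rng.standard_normal(d + 1)

# Separate key/query head: scores are (W_K Z)^T (W_Q z_query).
out_sep = W_V @ Z @ (W_K @ Z).T @ (W_Q @ z_q) / N

# Merging into W_KQ = W_K^T W_Q gives the same map, but W_KQ has rank <= r.
W_KQ = W_K.T @ W_Q
out_merged = W_V @ (Z @ Z.T / N) @ W_KQ @ z_q
print(np.allclose(out_sep, out_merged))  # True
```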
Simulate a loss trajectory of linear attention with rank-one separate key and query:

```
python train.py --model attnS --head 5 --epoch 10001 --lr 0.02 --show
```

We vary the rank of the key and query weights (via the --rank flag) and see how the loss trajectories differ from their rank-one counterpart. We set the input token dimension to 8 and the sequence length to 32. The following commands write the loss curves to txt files; add the --show flag to also display them.
```
R="8 4 2 1"
for r in $R; do
    python train.py --model attnS --head 9 --rank "$r" --in_dim 8 --seq_len 32 --init 5e-3 --trainset_size 80000 --epoch 20001 --lr 0.02
done
```

> **Tip:** All the commands we provide match what we used to generate the figures in our paper. To just play with the code, you can use a smaller training dataset, a larger initialization scale, and fewer training epochs; the loss curves may be a little noisier, but training runs faster.
The --icl flag controls the fraction of training sequences with randomly sampled task vectors. Its default is 1, which corresponds to a purely in-context learning task; setting it below 1 elicits in-weight learning.
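A hedged sketch of what such a mixture could look like (the repo's actual sampling scheme may differ): with probability icl a sequence gets a freshly sampled task vector, and otherwise it reuses one fixed vector, so icl=1 is pure in-context learning and icl=0 is pure in-weight learning.

```python
import numpy as np

def sample_tasks(batch, d, icl, rng, w_fixed):
    """Task vector per sequence: freshly sampled with prob. icl, fixed otherwise."""
    fresh = rng.standard_normal((batch, d))
    use_fresh = rng.random(batch) < icl  # which sequences are in-context
    return np.where(use_fresh[:, None], fresh, w_fixed)

rng = np.random.default_rng(3)
d = 8
w_fixed = rng.standard_normal(d)

w_icl = sample_tasks(1000, d, icl=1.0, rng=rng, w_fixed=w_fixed)  # all fresh
w_iwl = sample_tasks(1000, d, icl=0.0, rng=rng, w_fixed=w_fixed)  # all fixed
print(np.all(w_iwl == w_fixed), np.any(w_icl != w_fixed))
```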
```
C="0 0.2 0.4 0.6 0.8 1"
for c in $C; do
    python train.py --model attnM --icl "$c" --head 8 --testset 5000 --init 1e-6 --white_cov --lr 0.0005 --epoch 4001
done
```

If you find this code useful, please cite:

```
@InProceedings{yedi25icl,
  title     = {Training Dynamics of In-Context Learning in Linear Attention},
  author    = {Zhang, Yedi and Singh, Aaditya K. and Latham, Peter E. and Saxe, Andrew},
  booktitle = {Proceedings of the 42nd International Conference on Machine Learning},
  year      = {2025},
  url       = {https://arxiv.org/abs/2501.16265}
}
```