
Commit 38867c9: varlen attention tutorial
1 parent 86b1c62

File tree

2 files changed: +331, -0 lines


_static/img/varlen_diagram.png

207 KB
Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using Variable Length Attention in PyTorch\n",
    "\n",
    "## Summary\n",
    "\n",
    "In this tutorial, we introduce the variable length attention API, `varlen_attn`. It is a custom op in PyTorch, which means it can also be compiled with `torch.compile`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Note:**\n",
    "> This tutorial currently requires the PyTorch nightly build.\n",
    "\n",
    "### What you will learn\n",
    "\n",
    "- What variable length attention is and how it differs from `scaled_dot_product_attention`\n",
    "- How to use `varlen_attn` in a simple Transformer attention layer\n",
    "\n",
    "### Prerequisites\n",
    "\n",
    "- PyTorch v2.10.0.dev or later\n",
    "- A basic understanding of attention and PyTorch's existing attention APIs; see the FlexAttention and `scaled_dot_product_attention` (SDPA) tutorials for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview of Variable Length Attention\n",
    "\n",
    "In standard SDPA, every sequence in a batch is expected to have the same length. In practice, this means that input tensors are often **padded** to a common length, which wastes both memory and compute: the padding has to be stored, and attention is computed over tokens that carry no information.\n",
    "\n",
    "Variable length attention handles sequences of varying length by **packing** the tensors in a batch together, essentially collapsing the batch dimension. Note that NestedTensor is another way to enable variable length attention with packed tensors (see the tutorial [here](https://docs.pytorch.org/tutorials/unstable/nestedtensor.html)).\n",
    "\n",
    "However, we still need to maintain the boundaries between documents. To do so, we compute cumulative sequence positions for query and key/value that mark the ends of documents. In the diagram below, doc 1 is 7 tokens long, doc 2 is 10 tokens long, and so on, so `cu_seq_lens = [0, 7, 17, ...]`.\n",
    "\n",
    "![Padding vs Packing Diagram](../_static/img/varlen_diagram.png)\n"
   ]
  },
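  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the packing concrete, below is a small illustration (independent of the `varlen_attn` API) that packs two documents of different lengths into one tensor and derives the cumulative sequence lengths with `torch.cumsum`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# two documents of different lengths, each of shape (num_tokens, dim)\n",
    "doc1 = torch.randn(7, 16)\n",
    "doc2 = torch.randn(10, 16)\n",
    "\n",
    "# packing: concatenate along the token dimension instead of padding\n",
    "packed = torch.cat([doc1, doc2], dim=0)  # shape: (17, 16)\n",
    "\n",
    "# cumulative sequence lengths mark where each document ends\n",
    "seq_lens = torch.tensor([doc1.shape[0], doc2.shape[0]], dtype=torch.int32)\n",
    "cu_seq_lens = torch.cat(\n",
    "    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_lens, dim=0, dtype=torch.int32)]\n",
    ")\n",
    "print(packed.shape)  # torch.Size([17, 16])\n",
    "print(cu_seq_lens)  # tensor([ 0,  7, 17], dtype=torch.int32)"
   ]
  },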
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition\n",
    "\n",
    "Below is the definition of `varlen_attn`, which returns the output tensor of the attention computation.\n",
    "\n",
    "```python\n",
    "def varlen_attn(\n",
    "    query: torch.Tensor,\n",
    "    key: torch.Tensor,\n",
    "    value: torch.Tensor,\n",
    "    cu_seq_q: torch.Tensor,\n",
    "    cu_seq_k: torch.Tensor,\n",
    "    max_q: int,\n",
    "    max_k: int,\n",
    "    is_causal: bool = False,\n",
    "    return_aux: AuxRequest | None = None,\n",
    ") -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n",
    "```\n",
    "\n",
    "`query`, `key`, and `value` correspond to the `q`, `k`, and `v` of the packed input. `cu_seq_q` and `cu_seq_k` are the cumulative sequence indices for query and key/value, respectively; they mark the logical boundaries that separate the documents in our input. `max_q` and `max_k` are the maximum sequence lengths of query and key, respectively. `is_causal` applies causal masking when set to True, and `return_aux` specifies which auxiliary outputs to return (e.g., the log-sum-exp, `lse`).\n",
    "\n",
    "**Note on causal masking**\n",
    "\n",
    "When `is_causal` is set to True, causal masking is applied, which means that tokens can only attend to previous tokens. For bidirectional attention, set this flag to False.\n",
    "\n",
    "In torchtitan (PyTorch's pretraining framework), we set `is_causal = True` uniformly to prevent the model from cheating by attending to future tokens and artificially driving the loss down too quickly."
   ]
  },
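  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of a direct call (assuming a CUDA device and the nightly build mentioned above), we can pack two documents of 7 and 10 tokens and call `varlen_attn` on packed queries, keys, and values of shape `(total_tokens, num_heads, head_dim)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn.attention.varlen import varlen_attn\n",
    "\n",
    "device, dtype = \"cuda\", torch.bfloat16\n",
    "num_heads, head_dim = 4, 16\n",
    "\n",
    "# two documents of 7 and 10 tokens packed into 17 total tokens\n",
    "cu_seq = torch.tensor([0, 7, 17], dtype=torch.int32, device=device)\n",
    "max_len = 10  # length of the longest document\n",
    "\n",
    "q = torch.randn(17, num_heads, head_dim, device=device, dtype=dtype)\n",
    "k = torch.randn(17, num_heads, head_dim, device=device, dtype=dtype)\n",
    "v = torch.randn(17, num_heads, head_dim, device=device, dtype=dtype)\n",
    "\n",
    "out = varlen_attn(\n",
    "    query=q, key=k, value=v,\n",
    "    cu_seq_q=cu_seq, cu_seq_k=cu_seq,\n",
    "    max_q=max_len, max_k=max_len,\n",
    "    is_causal=True,\n",
    ")\n",
    "print(out.shape)  # expected: torch.Size([17, 4, 16])"
   ]
  },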
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example\n",
    "\n",
    "Let's walk through a simple example of how we would use `varlen_attn` in the context of training a Transformer model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the metadata\n",
    "\n",
    "Given an input batch, how would we construct the metadata that `varlen_attn` expects? More specifically, how do we calculate the cumulative sequence indices?\n",
    "\n",
    "The helper function `create_varlen_metadata` below returns the required `cu_seqlens` and `max_seqlen`, given `input_batch` and the end-of-sequence token ID that marks the end of each document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):\n",
    "    \"\"\"Compute packed cumulative sequence lengths and the max document length for a batch.\"\"\"\n",
    "    batch_size, seq_len = input_batch.shape\n",
    "    device = input_batch.device\n",
    "    cu_seqlens_list, all_seq_lengths = [], []\n",
    "    offset = 0\n",
    "\n",
    "    for b in range(batch_size):\n",
    "        tokens = input_batch[b]\n",
    "        eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)\n",
    "\n",
    "        # we use the positions of the eos tokens to mark the ends of documents\n",
    "        sample_cu_seqlens = torch.cat(\n",
    "            [\n",
    "                torch.tensor([0], dtype=torch.int32, device=device),\n",
    "                eos_positions + 1,\n",
    "                torch.tensor([seq_len], dtype=torch.int32, device=device),\n",
    "            ]\n",
    "        )\n",
    "        sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)\n",
    "\n",
    "        seq_lengths = torch.diff(sample_cu_seqlens)\n",
    "        all_seq_lengths.append(seq_lengths)\n",
    "\n",
    "        # shift this sample's boundaries by the number of tokens packed so far\n",
    "        cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset\n",
    "        cu_seqlens_list.append(cu_seqlens_adjusted)\n",
    "\n",
    "        offset += seq_len\n",
    "\n",
    "    packed_cu_seqlens = torch.cat(\n",
    "        cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]\n",
    "    )\n",
    "\n",
    "    max_seqlen = 0\n",
    "    if len(all_seq_lengths) > 0:\n",
    "        all_seq_lengths = torch.cat(all_seq_lengths)\n",
    "        max_seqlen = all_seq_lengths.max().item()\n",
    "\n",
    "    return packed_cu_seqlens, max_seqlen"
   ]
  },
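  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before wiring this into a model, here is a quick sanity check of the helper on a small hand-written batch (CPU is fine here, since the helper only manipulates indices):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a toy batch: eos_id = 2 marks document boundaries\n",
    "eos_id = 2\n",
    "toy_batch = torch.tensor(\n",
    "    [\n",
    "        [5, 5, 2, 5, 5, 5, 5, 2],  # two documents: 3 and 5 tokens\n",
    "        [5, 5, 5, 5, 2, 5, 5, 5],  # two documents: 5 and 3 tokens\n",
    "    ]\n",
    ")\n",
    "\n",
    "cu_seqlens, max_seqlen = create_varlen_metadata(toy_batch, eos_id)\n",
    "print(cu_seqlens)  # tensor([ 0,  3,  8, 13, 16], dtype=torch.int32)\n",
    "print(max_seqlen)  # 5"
   ]
  },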
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Attention Block\n",
    "\n",
    "Let's explore how we would use `varlen_attn` in an attention module. We define the module as usual, but in the `forward` method we call the new `varlen_attn` custom op.\n",
    "\n",
    "This op expects the `cu_seq` indices and `max_len` that we computed earlier with `create_varlen_metadata` to mark the boundaries between documents.\n",
    "\n",
    "Before we call `varlen_attn`, we also pack our input so that it has shape `(total_tokens, dim)`. Recall that variable length attention allows us to collapse the `batch_size` dimension so that we can lay out our input samples contiguously."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn.attention.varlen import varlen_attn\n",
    "\n",
    "\n",
    "class SimpleVarlenAttention(nn.Module):\n",
    "    def __init__(self, embed_dim: int, num_heads: int):\n",
    "        super().__init__()\n",
    "        self.embed_dim = embed_dim\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = embed_dim // num_heads\n",
    "\n",
    "        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)\n",
    "        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "\n",
    "    def forward(\n",
    "        self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
    "    ) -> torch.Tensor:\n",
    "        batch_size, seq_len, _ = x.shape\n",
    "        x_packed = x.view(batch_size * seq_len, -1)  # pack x into (total_tokens, dim)\n",
    "\n",
    "        qkv = self.qkv_proj(x_packed)\n",
    "        q, k, v = qkv.chunk(3, dim=-1)\n",
    "\n",
    "        # reshape to (total_tokens, num_heads, head_dim) as expected by varlen_attn\n",
    "        q = q.view(-1, self.num_heads, self.head_dim)\n",
    "        k = k.view(-1, self.num_heads, self.head_dim)\n",
    "        v = v.view(-1, self.num_heads, self.head_dim)\n",
    "\n",
    "        attn_out = varlen_attn(\n",
    "            query=q,\n",
    "            key=k,\n",
    "            value=v,\n",
    "            cu_seq_q=cu_seq,\n",
    "            cu_seq_k=cu_seq,\n",
    "            max_q=max_len,\n",
    "            max_k=max_len,\n",
    "            is_causal=True,\n",
    "        )\n",
    "        attn_out = attn_out.view(-1, self.embed_dim)\n",
    "        attn_out = self.out_proj(attn_out)\n",
    "        return attn_out.view(batch_size, seq_len, self.embed_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also use `torch.compile` with `varlen_attn`, for example:\n",
    "\n",
    "```python\n",
    "compiled_varlen_attn = torch.compile(\n",
    "    varlen_attn, mode=\"max-autotune-no-cudagraphs\"\n",
    ")\n",
    "```\n",
    "\n",
    "We can then call `compiled_varlen_attn` instead of `varlen_attn` in the attention `forward`, and everything else stays the same."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a Transformer\n",
    "\n",
    "Now, we can use this `SimpleVarlenAttention` module in a simple Transformer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleVarlenTransformer(nn.Module):\n",
    "    \"\"\"A simple one-layer Transformer with varlen attention.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):\n",
    "        super().__init__()\n",
    "        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)\n",
    "        self.attention = SimpleVarlenAttention(embed_dim, num_heads)\n",
    "        self.norm = nn.LayerNorm(embed_dim)\n",
    "\n",
    "    def forward(\n",
    "        self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
    "    ) -> torch.Tensor:\n",
    "        x = self.tok_embeddings(tokens)\n",
    "        x = x + self.attention(x, cu_seq, max_len)  # residual connection\n",
    "        x = self.norm(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running a Training Step\n",
    "\n",
    "Now we're ready to put all the pieces together! Let's run a training step with our `SimpleVarlenTransformer`. We define our model, compute `cu_seq` and `max_len` using `create_varlen_metadata`, and run a forward and backward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    torch.manual_seed(42)\n",
    "\n",
    "    batch_size = 3\n",
    "    seq_len = 64\n",
    "    vocab_size = 1000\n",
    "    embed_dim = 128\n",
    "    num_heads = 4\n",
    "    eos_id = 2\n",
    "    num_docs = 3\n",
    "    device = \"cuda\"\n",
    "    dtype = torch.bfloat16\n",
    "\n",
    "    model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(\n",
    "        device=device, dtype=dtype\n",
    "    )\n",
    "\n",
    "    # create input_batch tokens\n",
    "    input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)\n",
    "\n",
    "    for b in range(batch_size):\n",
    "        # pick random positions to cut each sample into multiple documents\n",
    "        doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))\n",
    "        for pos in doc_positions:\n",
    "            input_batch[b, pos] = eos_id  # insert eos token to simulate end of a document\n",
    "        input_batch[b, -1] = eos_id\n",
    "\n",
    "    cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)\n",
    "    print(f\"cu_seq: {cu_seq}, max_len: {max_len}\")  # e.g. cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40\n",
    "\n",
    "    # fwd pass\n",
    "    output = model(input_batch, cu_seq, max_len)\n",
    "    print(f\"output shape: {output.shape}\")  # (3, 64, 128)\n",
    "\n",
    "    # bwd pass\n",
    "    loss = output.mean()\n",
    "    loss.backward()\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Limitations\n",
    "\n",
    "Note that this op currently works only on NVIDIA GPUs (A100 or newer). The supported dtypes are BF16 and FP16.\n"
   ]
  },
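  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, optional sanity check for these requirements (a sketch, not part of the `varlen_attn` API), we can verify the device's compute capability and pick a supported dtype before calling `varlen_attn`; A100 corresponds to compute capability 8.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def varlen_supported() -> bool:\n",
    "    # varlen_attn currently requires an NVIDIA GPU with compute capability >= 8.0 (A100 or newer)\n",
    "    if not torch.cuda.is_available():\n",
    "        return False\n",
    "    major, _ = torch.cuda.get_device_capability()\n",
    "    return major >= 8\n",
    "\n",
    "\n",
    "print(varlen_supported())\n",
    "# choose one of the supported dtypes, e.g. torch.bfloat16 or torch.float16"
   ]
  }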
 ],
 "metadata": {
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
