Commit 84481f1

varlen attention tutorial

1 file changed: 309 additions, 0 deletions
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Variable Length Attention in PyTorch\n",
"\n",
"## Summary\n",
"\n",
"In this tutorial, we introduce `varlen_attn`, a variable length attention API. It is implemented as a custom op in PyTorch, which means it can also be compiled with `torch.compile`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note:**\n",
"> This tutorial currently requires the PyTorch nightly build.\n",
"\n",
"### What you will learn\n",
"\n",
"- What variable length attention is and how it differs from `scaled_dot_product_attention`\n",
"- How to use `varlen_attn` in a simple Transformer attention layer\n",
"\n",
"### Prerequisites\n",
"\n",
"- PyTorch v2.10.0.dev or later\n",
"- A basic understanding of attention in PyTorch; see the FlexAttention and `scaled_dot_product_attention` (SDPA) tutorials for background on the existing attention APIs.\n",
"\n",
"The optional cell below sketches a quick environment check."
]
},
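{
"cell_type": "markdown",
"metadata": {},
"source": [
"The check below is a rough sketch rather than an official requirements test: the compute-capability threshold of 8.0 (A100 or newer, a restriction discussed later in this tutorial) and the nightly-build requirement are stated assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# optional sanity check (assumptions: a recent PyTorch nightly build and a CUDA\n",
"# device with compute capability >= 8.0, i.e. A100 or newer)\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"assert torch.cuda.is_available(), \"varlen_attn currently requires a CUDA device\"\n",
"major, minor = torch.cuda.get_device_capability()\n",
"assert (major, minor) >= (8, 0), \"varlen_attn currently targets A100-class GPUs or newer\""
]
},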
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview of Variable Length Attention\n",
"\n",
"In standard SDPA, all sequences in a batch are expected to have the same length. In practice, this means that input tensors are often **padded** to a common length. This wastes both memory and compute: the padding must be stored, and attention is computed over padding tokens unnecessarily.\n",
"\n",
"Variable length attention handles sequences of varying length by **packing** the tensors in a batch together, essentially collapsing the batch dimension. Note that NestedTensor is another way to enable variable length with packed tensors (see the tutorial [here](https://docs.pytorch.org/tutorials/unstable/nestedtensor.html)).\n",
"\n",
"However, we still need to maintain the boundaries between documents. To do so, we compute cumulative sequence positions for query and key/value that mark the ends of documents. For example, if doc 1 is 3 tokens long and doc 2 is 5 tokens long, then `cu_seq = [0, 3, 8]`, as the short cell below illustrates."
]
},
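{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the hypothetical 3-token and 5-token documents from above, the cumulative sequence positions are simply a cumulative sum of the document lengths with a leading zero:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# document lengths for the example above: doc 1 has 3 tokens, doc 2 has 5 tokens\n",
"doc_lens = torch.tensor([3, 5])\n",
"cu_seq = torch.cat(\n",
"    [torch.zeros(1, dtype=torch.int32), torch.cumsum(doc_lens, dim=0).to(torch.int32)]\n",
")\n",
"print(cu_seq)  # tensor([0, 3, 8], dtype=torch.int32)"
]
},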
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is the definition of `varlen_attn`.\n",
"\n",
"```python\n",
"def varlen_attn(\n",
"    query: torch.Tensor,\n",
"    key: torch.Tensor,\n",
"    value: torch.Tensor,\n",
"    cu_seq_q: torch.Tensor,\n",
"    cu_seq_k: torch.Tensor,\n",
"    max_q: int,\n",
"    max_k: int,\n",
"    is_causal: bool = False,\n",
"    return_aux: AuxRequest | None = None,\n",
") -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n",
"```\n",
"\n",
"`query`, `key`, and `value` correspond to the `q`, `k`, and `v` of the packed input. `cu_seq_q` and `cu_seq_k` are the cumulative indices for query and key/value, respectively; they mark the logical boundaries that separate the documents in our input. `max_q` and `max_k` are the maximum sequence lengths of query and key, respectively. `is_causal` applies causal masking if set to `True`, and `return_aux` specifies which auxiliary outputs to return (e.g., the log-sum-exp `lse`).\n",
"\n",
"`varlen_attn` returns the output tensor from the attention computation (plus any requested auxiliary outputs).\n",
"\n",
"Note that this op currently only works with NVIDIA CUDA on A100 GPUs or newer. Supported dtypes include BF16 and FP16. The cells below sketch a direct call to the op and check it against per-document SDPA."
]
},
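{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for what `varlen_attn` computes, the sketch below packs two documents (with illustrative lengths 3 and 5), runs `varlen_attn` once over the packed tensors, and compares the result against running `scaled_dot_product_attention` separately on each document slice. The document lengths, head counts, and tolerances are arbitrary choices made for this example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.nn.attention.varlen import varlen_attn\n",
"\n",
"device, dtype = \"cuda\", torch.bfloat16\n",
"num_heads, head_dim = 4, 32\n",
"\n",
"# two packed documents: 3 tokens + 5 tokens = 8 total tokens\n",
"cu_seq = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)\n",
"total_tokens, max_len = 8, 5\n",
"\n",
"q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)\n",
"k = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)\n",
"v = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)\n",
"\n",
"out = varlen_attn(\n",
"    q, k, v, cu_seq_q=cu_seq, cu_seq_k=cu_seq, max_q=max_len, max_k=max_len, is_causal=True\n",
")\n",
"\n",
"# reference: run SDPA on each document slice independently\n",
"for start, end in zip(cu_seq[:-1].tolist(), cu_seq[1:].tolist()):\n",
"    # SDPA expects (batch, heads, seq, head_dim)\n",
"    q_doc = q[start:end].transpose(0, 1).unsqueeze(0)\n",
"    k_doc = k[start:end].transpose(0, 1).unsqueeze(0)\n",
"    v_doc = v[start:end].transpose(0, 1).unsqueeze(0)\n",
"    ref = F.scaled_dot_product_attention(q_doc, k_doc, v_doc, is_causal=True)\n",
"    ref = ref.squeeze(0).transpose(0, 1)  # back to (doc_len, heads, head_dim)\n",
"    # loose tolerances since we are comparing bf16 kernels\n",
"    torch.testing.assert_close(out[start:end], ref, atol=2e-2, rtol=2e-2)\n",
"print(\"varlen_attn matches per-document SDPA\")"
]
},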
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given an input batch, how would we construct the metadata that `varlen_attn` expects? More specifically, how do we calculate the cumulative sequence indices?\n",
"\n",
"The helper function `create_varlen_metadata` returns the required `cu_seqlens` and `max_seqlen` given `input_batch` and the end-of-sequence token ID that marks the end of documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):\n",
"    batch_size, seq_len = input_batch.shape\n",
"    device = input_batch.device\n",
"    cu_seqlens_list, all_seq_lengths = [], []\n",
"    offset = 0\n",
"\n",
"    for b in range(batch_size):\n",
"        tokens = input_batch[b]\n",
"        eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)\n",
"\n",
"        # we use the position of the eos tokens to mark the end of documents\n",
"        sample_cu_seqlens = torch.cat(\n",
"            [\n",
"                torch.tensor([0], dtype=torch.int32, device=device),\n",
"                eos_positions + 1,\n",
"                torch.tensor([seq_len], dtype=torch.int32, device=device),\n",
"            ]\n",
"        )\n",
"        sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)\n",
"\n",
"        seq_lengths = torch.diff(sample_cu_seqlens)\n",
"        all_seq_lengths.append(seq_lengths)\n",
"\n",
"        cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset\n",
"        cu_seqlens_list.append(cu_seqlens_adjusted)\n",
"\n",
"        offset += seq_len\n",
"\n",
"    packed_cu_seqlens = torch.cat(\n",
"        cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]\n",
"    )\n",
"\n",
"    max_seqlen = 0\n",
"    if len(all_seq_lengths) > 0:\n",
"        all_seq_lengths = torch.cat(all_seq_lengths)\n",
"        max_seqlen = all_seq_lengths.max().item()\n",
"\n",
"    return packed_cu_seqlens, max_seqlen"
]
},
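{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check, we can run the helper on a small, made-up batch (the token values and the choice of `eos_id = 2` are arbitrary) and inspect the packed offsets it produces:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy example: 2 sequences of length 8, with token id 2 marking document ends\n",
"eos_id = 2\n",
"toy_batch = torch.tensor(\n",
"    [\n",
"        [5, 7, 2, 9, 4, 6, 3, 2],\n",
"        [8, 2, 1, 1, 5, 2, 7, 4],\n",
"    ]\n",
")\n",
"cu_seqlens, max_seqlen = create_varlen_metadata(toy_batch, eos_id)\n",
"print(cu_seqlens)  # tensor([ 0,  3,  8, 10, 14, 16], dtype=torch.int32)\n",
"print(max_seqlen)  # 5"
]
},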
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's explore how we would use `varlen_attn` in an attention module. We define an attention module as usual, but in the `forward` method, we call the new `varlen_attn` custom op.\n",
"\n",
"`varlen_attn` expects the `cu_seq` indices and `max_len` that we computed earlier with `create_varlen_metadata` to mark the boundaries between the different documents.\n",
"\n",
"Before we call `varlen_attn`, we also pack our input so that it has the shape `(total_tokens, dim)`. Recall that variable length attention allows us to collapse the `batch_size` dimension so that we can lay out our input samples contiguously."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.attention.varlen import varlen_attn\n",
"\n",
"\n",
"class SimpleVarlenAttention(nn.Module):\n",
"    def __init__(self, embed_dim: int, num_heads: int):\n",
"        super().__init__()\n",
"        self.embed_dim = embed_dim\n",
"        self.num_heads = num_heads\n",
"        self.head_dim = embed_dim // num_heads\n",
"\n",
"        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)\n",
"        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
"\n",
"    def forward(\n",
"        self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
"    ) -> torch.Tensor:\n",
"        batch_size, seq_len, _ = x.shape\n",
"        x_packed = x.view(batch_size * seq_len, -1)  # pack x into (total_tokens, dim)\n",
"\n",
"        qkv = self.qkv_proj(x_packed)\n",
"        q, k, v = qkv.chunk(3, dim=-1)\n",
"\n",
"        q = q.view(-1, self.num_heads, self.head_dim)\n",
"        k = k.view(-1, self.num_heads, self.head_dim)\n",
"        v = v.view(-1, self.num_heads, self.head_dim)\n",
"\n",
"        attn_out = varlen_attn(\n",
"            query=q,\n",
"            key=k,\n",
"            value=v,\n",
"            cu_seq_q=cu_seq,\n",
"            cu_seq_k=cu_seq,\n",
"            max_q=max_len,\n",
"            max_k=max_len,\n",
"            is_causal=True,\n",
"        )\n",
"        attn_out = attn_out.view(-1, self.embed_dim)\n",
"        attn_out = self.out_proj(attn_out)\n",
"        return attn_out.view(batch_size, seq_len, self.embed_dim)"
]
},
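{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a quick, hypothetical smoke test of the module. The batch shape, document boundaries, and dtype are arbitrary illustrative choices, and `cu_seq`/`max_len` are written by hand here rather than computed with `create_varlen_metadata`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device, dtype = \"cuda\", torch.bfloat16\n",
"embed_dim, num_heads = 128, 4\n",
"batch_size, seq_len = 2, 8\n",
"\n",
"attn = SimpleVarlenAttention(embed_dim, num_heads).to(device=device, dtype=dtype)\n",
"x = torch.randn(batch_size, seq_len, embed_dim, device=device, dtype=dtype)\n",
"\n",
"# document boundaries over the packed (2 * 8 = 16 token) layout:\n",
"# row 0 holds documents of lengths 3 and 5, row 1 holds lengths 6 and 2\n",
"cu_seq = torch.tensor([0, 3, 8, 14, 16], dtype=torch.int32, device=device)\n",
"max_len = 6\n",
"\n",
"out = attn(x, cu_seq, max_len)\n",
"print(out.shape)  # torch.Size([2, 8, 128])"
]
},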
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use `torch.compile` with `varlen_attn`. For example, we can define\n",
"\n",
"```python\n",
"compiled_varlen_attn: ClassVar[Callable] = torch.compile(\n",
"    varlen_attn, mode=\"max-autotune-no-cudagraphs\"\n",
")\n",
"```\n",
"\n",
"We can then call `compiled_varlen_attn` instead of `varlen_attn` in the attention `forward`, and everything else stays the same, as sketched below."
]
},
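{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of that swap. The module-level `compiled_varlen_attn` variable and the class name `CompiledVarlenAttention` are illustrative choices, not part of the API; only the attention call changes relative to `SimpleVarlenAttention` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"compiled_varlen_attn = torch.compile(varlen_attn, mode=\"max-autotune-no-cudagraphs\")\n",
"\n",
"\n",
"class CompiledVarlenAttention(SimpleVarlenAttention):\n",
"    \"\"\"Identical to SimpleVarlenAttention, except the attention call is compiled.\"\"\"\n",
"\n",
"    def forward(\n",
"        self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
"    ) -> torch.Tensor:\n",
"        batch_size, seq_len, _ = x.shape\n",
"        x_packed = x.view(batch_size * seq_len, -1)\n",
"\n",
"        qkv = self.qkv_proj(x_packed)\n",
"        q, k, v = qkv.chunk(3, dim=-1)\n",
"        q = q.view(-1, self.num_heads, self.head_dim)\n",
"        k = k.view(-1, self.num_heads, self.head_dim)\n",
"        v = v.view(-1, self.num_heads, self.head_dim)\n",
"\n",
"        # the only change from SimpleVarlenAttention: call the compiled op\n",
"        attn_out = compiled_varlen_attn(\n",
"            query=q, key=k, value=v,\n",
"            cu_seq_q=cu_seq, cu_seq_k=cu_seq,\n",
"            max_q=max_len, max_k=max_len,\n",
"            is_causal=True,\n",
"        )\n",
"        attn_out = self.out_proj(attn_out.view(-1, self.embed_dim))\n",
"        return attn_out.view(batch_size, seq_len, self.embed_dim)"
]
},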
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use this `SimpleVarlenAttention` module in a simple Transformer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleVarlenTransformer(nn.Module):\n",
"    \"\"\"\n",
"    Simple one-layer transformer with varlen attention.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):\n",
"        super().__init__()\n",
"        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)\n",
"        self.attention = SimpleVarlenAttention(embed_dim, num_heads)\n",
"        self.norm = nn.LayerNorm(embed_dim)\n",
"\n",
"    def forward(\n",
"        self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
"    ) -> torch.Tensor:\n",
"        x = self.tok_embeddings(tokens)\n",
"        x = x + self.attention(x, cu_seq, max_len)\n",
"        x = self.norm(x)\n",
"        return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we're ready to put all the pieces together! Let's run a training step with our `SimpleVarlenTransformer`. We define our model, compute `cu_seq` and `max_len` using `create_varlen_metadata`, and run a forward and backward pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def main():\n",
"    torch.manual_seed(42)\n",
"\n",
"    batch_size = 3\n",
"    seq_len = 64\n",
"    vocab_size = 1000\n",
"    embed_dim = 128\n",
"    num_heads = 4\n",
"    eos_id = 2\n",
"    num_docs = 3\n",
"    device = \"cuda\"\n",
"    dtype = torch.bfloat16\n",
"\n",
"    model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(\n",
"        device=device, dtype=dtype\n",
"    )\n",
"\n",
"    # create input_batch tokens\n",
"    input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)\n",
"\n",
"    for b in range(batch_size):\n",
"        # getting random positions to cut the input into multiple documents\n",
"        doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))\n",
"        for pos in doc_positions:\n",
"            input_batch[b, pos] = eos_id  # insert eos token to simulate end of sample\n",
"        input_batch[b, -1] = eos_id\n",
"\n",
"    cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)\n",
"    print(f\"cu_seq: {cu_seq}, max_len: {max_len}\")  # cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40\n",
"\n",
"    # fwd pass\n",
"    output = model(input_batch, cu_seq, max_len)\n",
"    print(f\"output shape: {output.shape}\")  # (3, 64, 128)\n",
"\n",
"    # bwd pass\n",
"    loss = output.mean()\n",
"    loss.backward()\n",
"\n",
"    print(f\"embedding grad shape: {model.tok_embeddings.weight.grad.shape}\")  # (1000, 128)\n",
"    print(f\"embedding grad norm: {model.tok_embeddings.weight.grad.norm().item()}\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    main()"
]
}
],
"metadata": {
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
