Commit 48ac2e2

varlen attention tutorial
1 parent 86b1c62 commit 48ac2e2

File tree: 1 file changed (+302, -0 lines)

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Using Variable Length Attention in PyTorch\n",
8+
"\n",
9+
"## Summary\n",
10+
"\n",
11+
"In this tutorial, we introduce a variable length attention API, `varlen_attn`. It is implemented as a custom op in PyTorch, which means it can also be compiled with `torch.compile`."
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"> **Note:** \n",
19+
"> This tutorial currently requires you to use the PyTorch nightly build.\n",
20+
"\n",
21+
"### What you will learn\n",
22+
"\n",
23+
"- Variable length attention and how it differs from `scaled_dot_product_attention`\n",
24+
"- An example of how to use `varlen_attn` in a simple Transformer attention layer\n",
25+
"\n",
26+
"### Prerequisites\n",
27+
"\n",
28+
"- PyTorch v2.10.0.dev or later\n",
29+
"- A basic understanding of attention and of PyTorch's existing attention APIs; see the FlexAttention and scaled_dot_product_attention (SDPA) tutorials for more details."
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"## Overview of Variable Length Attention \n",
37+
"\n",
38+
"In standard SDPA, sequences in a batch are expected to have a fixed length. In practice, this means input tensors are often **padded** to a common length, which wastes both memory (storing padding tokens) and compute (attending over tokens that carry no information).\n",
39+
"\n",
40+
"Variable length attention handles sequences of varying length by **packing** the tensors in a batch together and essentially collapsing the batch dimension. Note that NestedTensor is another way to enable variable length with packed tensors (see tutorial [here](https://docs.pytorch.org/tutorials/unstable/nestedtensor.html)).\n",
41+
"\n",
42+
"However, we still need to maintain the boundaries between documents. To do so, we compute cumulative sequence positions for query and key/value that mark the end of documents. For example, if doc 1 is 3 tokens long and doc 2 is 5 tokens long, then `cu_seq = [0, 3, 8]`."
43+
]
44+
},
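{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cumulative indices are simply a running sum of the document lengths with a leading zero. As a small illustration of that relationship (a sketch only, not part of the API), the two-document example above can be reproduced with `torch.cumsum`:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"doc_lens = torch.tensor([3, 5], dtype=torch.int32)  # doc 1 is 3 tokens, doc 2 is 5 tokens\n",
"cu_seq = torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(doc_lens, dim=0, dtype=torch.int32)])\n",
"print(cu_seq)  # tensor([0, 3, 8], dtype=torch.int32)\n",
"```"
]
},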
45+
{
46+
"cell_type": "markdown",
47+
"metadata": {},
48+
"source": [
49+
"Below is the definition of `varlen_attn`. \n",
50+
"\n",
51+
"```python\n",
52+
"def varlen_attn(\n",
53+
" query: torch.Tensor,\n",
54+
" key: torch.Tensor,\n",
55+
" value: torch.Tensor,\n",
56+
" cu_seq_q: torch.Tensor,\n",
57+
" cu_seq_k: torch.Tensor,\n",
58+
" max_q: int,\n",
59+
" max_k: int,\n",
60+
" is_causal: bool = False,\n",
61+
" return_aux: AuxRequest | None = None,\n",
62+
") -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n",
63+
"```\n",
64+
"\n",
65+
"`query`, `key`, and `value` correspond to the `q`, `k`, and `v` of the packed input. `cu_seq_q` and `cu_seq_k` are the cumulative indices for query and key/value, respectively; they mark the logical boundaries that separate the documents in our input. `max_q` and `max_k` are the maximum sequence lengths of query and key, respectively. `is_causal` applies causal masking when set to `True`, and `return_aux` specifies which auxiliary outputs to return (e.g., the log-sum-exp, `lse`).\n",
66+
"\n",
67+
"`varlen_attn` returns the output tensor from the attention computation. \n",
68+
"\n",
69+
"Note that this op currently works only on NVIDIA CUDA devices (A100 or newer). Supported dtypes include BF16 and FP16."
70+
]
71+
},
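{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring `varlen_attn` into a module, it can help to see a standalone call. The cell below is a minimal sketch under the prerequisites above (nightly build, A100-or-newer CUDA GPU); the tensor sizes and the two-document layout are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn.attention.varlen import varlen_attn\n",
"\n",
"# two documents of lengths 3 and 5 packed into 8 total tokens\n",
"num_heads, head_dim = 4, 16\n",
"total_tokens = 8\n",
"cu_seq = torch.tensor([0, 3, 8], dtype=torch.int32, device=\"cuda\")\n",
"max_len = 5  # length of the longest document\n",
"\n",
"q = torch.randn(total_tokens, num_heads, head_dim, device=\"cuda\", dtype=torch.bfloat16)\n",
"k = torch.randn_like(q)\n",
"v = torch.randn_like(q)\n",
"\n",
"out = varlen_attn(\n",
"    query=q,\n",
"    key=k,\n",
"    value=v,\n",
"    cu_seq_q=cu_seq,\n",
"    cu_seq_k=cu_seq,\n",
"    max_q=max_len,\n",
"    max_k=max_len,\n",
"    is_causal=True,\n",
")\n",
"print(out.shape)  # expected: torch.Size([8, 4, 16]), the same packed layout as the inputs"
]
},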
72+
{
73+
"cell_type": "markdown",
74+
"metadata": {},
75+
"source": [
76+
"Given an input batch, how would we construct the metadata that `varlen_attn` expects? More specifically, how do we calculate the cumulative sequence indices? \n",
77+
"\n",
78+
"The helper function `create_varlen_metadata` below returns the required `cu_seqlens` and `max_seqlen`, given an `input_batch` and the end-of-sequence (EOS) token ID that marks document boundaries."
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": null,
84+
"metadata": {},
85+
"outputs": [],
86+
"source": [
87+
"import torch\n",
88+
"\n",
89+
"def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):\n",
90+
" batch_size, seq_len = input_batch.shape\n",
91+
" device = input_batch.device\n",
92+
" cu_seqlens_list, all_seq_lengths = [], []\n",
93+
" offset = 0\n",
94+
"\n",
95+
" for b in range(batch_size):\n",
96+
" tokens = input_batch[b]\n",
97+
" eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)\n",
98+
"\n",
99+
" # we use the position of the eos tokens to mark the end of documents\n",
100+
" sample_cu_seqlens = torch.cat(\n",
101+
" [\n",
102+
" torch.tensor([0], dtype=torch.int32, device=device),\n",
103+
" eos_positions + 1,\n",
104+
" torch.tensor([seq_len], dtype=torch.int32, device=device),\n",
105+
" ]\n",
106+
" )\n",
107+
" sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)\n",
108+
"\n",
109+
" seq_lengths = torch.diff(sample_cu_seqlens)\n",
110+
" all_seq_lengths.append(seq_lengths)\n",
111+
"\n",
112+
" cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset\n",
113+
" cu_seqlens_list.append(cu_seqlens_adjusted)\n",
114+
"\n",
115+
" offset += seq_len\n",
116+
"\n",
117+
" packed_cu_seqlens = torch.cat(\n",
118+
" cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]\n",
119+
" )\n",
120+
"\n",
121+
" max_seqlen = 0\n",
122+
" if len(all_seq_lengths) > 0:\n",
123+
" all_seq_lengths = torch.cat(all_seq_lengths)\n",
124+
" max_seqlen = all_seq_lengths.max().item()\n",
125+
"\n",
126+
" return packed_cu_seqlens, max_seqlen"
127+
]
128+
},
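{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, here is the helper applied to a small hand-made batch. The token values and EOS placement are arbitrary, chosen only for illustration; the printed values assume the helper defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy batch: 2 samples of length 8, with eos_id = 2 marking document ends\n",
"toy_batch = torch.tensor(\n",
"    [\n",
"        [5, 7, 2, 9, 4, 4, 2, 8],  # documents of lengths 3, 4, 1\n",
"        [1, 2, 6, 6, 6, 2, 3, 3],  # documents of lengths 2, 4, 2\n",
"    ]\n",
")\n",
"\n",
"toy_cu_seqlens, toy_max_seqlen = create_varlen_metadata(toy_batch, eos_id=2)\n",
"print(toy_cu_seqlens)  # tensor([ 0,  3,  7,  8, 10, 14, 16], dtype=torch.int32)\n",
"print(toy_max_seqlen)  # 4"
]
},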
129+
{
130+
"cell_type": "markdown",
131+
"metadata": {},
132+
"source": [
133+
"Let's explore how we would use `varlen_attn` in an Attention module. We define an attention module as usual, but in the `forward` method, we call the new `varlen_attn` custom op. \n",
134+
"\n",
135+
"This function expects the `cu_seq` indices and `max_len` that we computed earlier using `create_varlen_metadata` to mark the boundaries of the different documents. \n",
136+
"\n",
137+
"Before we call `varlen_attn`, we also pack our input so that it has the shape `(total_tokens, dim)`. Recall that variable length attention allows us to collapse the `batch_size` dimension so that we can lay out our input samples contiguously."
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"import torch\n",
147+
"import torch.nn as nn\n",
148+
"from torch.nn.attention.varlen import varlen_attn\n",
149+
"\n",
150+
"\n",
151+
"class SimpleVarlenAttention(nn.Module):\n",
152+
" def __init__(self, embed_dim: int, num_heads: int):\n",
153+
" super().__init__()\n",
154+
" self.embed_dim = embed_dim\n",
155+
" self.num_heads = num_heads\n",
156+
" self.head_dim = embed_dim // num_heads\n",
157+
"\n",
158+
" self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)\n",
159+
" self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
160+
"\n",
161+
" def forward(\n",
162+
" self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
163+
" ) -> torch.Tensor:\n",
164+
" batch_size, seq_len, _ = x.shape\n",
165+
" x_packed = x.view(batch_size * seq_len, -1) # pack x into (total_tokens, dim)\n",
166+
"\n",
167+
" qkv = self.qkv_proj(x_packed)\n",
168+
" q, k, v = qkv.chunk(3, dim=-1)\n",
169+
"\n",
170+
" q = q.view(-1, self.num_heads, self.head_dim)  # (total_tokens, num_heads, head_dim)\n",
171+
" k = k.view(-1, self.num_heads, self.head_dim)\n",
172+
" v = v.view(-1, self.num_heads, self.head_dim)\n",
173+
"\n",
174+
" attn_out = varlen_attn(\n",
175+
" query=q,\n",
176+
" key=k,\n",
177+
" value=v,\n",
178+
" cu_seq_q=cu_seq,\n",
179+
" cu_seq_k=cu_seq,\n",
180+
" max_q=max_len,\n",
181+
" max_k=max_len,\n",
182+
" is_causal=True,\n",
183+
" )\n",
184+
" attn_out = attn_out.view(-1, self.embed_dim)\n",
185+
" attn_out = self.out_proj(attn_out)\n",
186+
" return attn_out.view(batch_size, seq_len, self.embed_dim)"
187+
]
188+
},
189+
{
190+
"cell_type": "markdown",
191+
"metadata": {},
192+
"source": [
193+
"Because `varlen_attn` is a custom op, it also works with `torch.compile`. For example, we can define\n",
194+
"\n",
195+
"```python \n",
196+
"compiled_varlen_attn = torch.compile(\n",
197+
" varlen_attn, mode=\"max-autotune-no-cudagraphs\"\n",
198+
")\n",
199+
"```\n",
200+
"\n",
201+
"We can call `compiled_varlen_attn` instead of `varlen_attn` in the Attention forward, and everything else stays the same."
202+
]
203+
},
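{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, inside `SimpleVarlenAttention.forward` the only change is the call site. This is a sketch of the swap, assuming `compiled_varlen_attn` has been defined as above:\n",
"\n",
"```python\n",
"attn_out = compiled_varlen_attn(\n",
"    query=q,\n",
"    key=k,\n",
"    value=v,\n",
"    cu_seq_q=cu_seq,\n",
"    cu_seq_k=cu_seq,\n",
"    max_q=max_len,\n",
"    max_k=max_len,\n",
"    is_causal=True,\n",
")\n",
"```"
]
},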
204+
{
205+
"cell_type": "markdown",
206+
"metadata": {},
207+
"source": [
208+
"Now, we can use this `SimpleVarlenAttention` module in a simple Transformer."
209+
]
210+
},
211+
{
212+
"cell_type": "code",
213+
"execution_count": null,
214+
"metadata": {},
215+
"outputs": [],
216+
"source": [
217+
"class SimpleVarlenTransformer(nn.Module):\n",
218+
" \"\"\"\n",
219+
" A simple one-layer Transformer with variable length attention.\n",
220+
" \"\"\"\n",
221+
"\n",
222+
" def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):\n",
223+
" super().__init__()\n",
224+
" self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)\n",
225+
" self.attention = SimpleVarlenAttention(embed_dim, num_heads)\n",
226+
" self.norm = nn.LayerNorm(embed_dim)\n",
227+
"\n",
228+
" def forward(\n",
229+
" self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
230+
" ) -> torch.Tensor:\n",
231+
" x = self.tok_embeddings(tokens)\n",
232+
" x = x + self.attention(x, cu_seq, max_len)\n",
233+
" x = self.norm(x)\n",
234+
" return x"
235+
]
236+
},
237+
{
238+
"cell_type": "markdown",
239+
"metadata": {},
240+
"source": [
241+
"Now we're ready to put all the pieces together! Let's run a training step with our `SimpleVarlenTransformer`. We define our model, compute `cu_seq` and `max_len` using `create_varlen_metadata`, and run a forward and backward pass. "
242+
]
243+
},
244+
{
245+
"cell_type": "code",
246+
"execution_count": null,
247+
"metadata": {},
248+
"outputs": [],
249+
"source": [
250+
"def main():\n",
251+
" torch.manual_seed(42)\n",
252+
"\n",
253+
" batch_size = 3\n",
254+
" seq_len = 64\n",
255+
" vocab_size = 1000\n",
256+
" embed_dim = 128\n",
257+
" num_heads = 4\n",
258+
" eos_id = 2\n",
259+
" num_docs = 3\n",
260+
" device = \"cuda\"\n",
261+
" dtype = torch.bfloat16\n",
262+
"\n",
263+
" model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(\n",
264+
" device=device, dtype=dtype\n",
265+
" )\n",
266+
"\n",
267+
" # create input_batch tokens\n",
268+
" input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)\n",
269+
"\n",
270+
" for b in range(batch_size):\n",
271+
" # getting random positions to cut the input into multiple documents\n",
272+
" doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))\n",
273+
" for pos in doc_positions:\n",
274+
" input_batch[b, pos] = eos_id # insert eos token to simulate end of sample\n",
275+
" input_batch[b, -1] = eos_id\n",
276+
"\n",
277+
" cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)\n",
278+
" print(f\"cu_seq: {cu_seq}, max_len: {max_len}\") # cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40\n",
279+
"\n",
280+
" # fwd pass\n",
281+
" output = model(input_batch, cu_seq, max_len)\n",
282+
" print(f\"output shape: {output.shape}\") # (3, 64, 128)\n",
283+
"\n",
284+
" # bwd pass\n",
285+
" loss = output.mean()\n",
286+
" loss.backward()\n",
287+
"\n",
288+
" print(f\"embedding grad shape: {model.tok_embeddings.weight.grad.shape}\") # (1000, 128)\n",
289+
" print(f\"embedding grad norm: {model.tok_embeddings.weight.grad.norm().item()}\")\n",
290+
"\n",
291+
"\n",
292+
"if __name__ == \"__main__\":\n",
293+
" main()"
294+
]
295+
}
296+
],
297+
"metadata": {
298+
"orig_nbformat": 4
299+
},
300+
"nbformat": 4,
301+
"nbformat_minor": 2
302+
}
