1 | 1 | """
2 | 2 |
3 | | -Nested Tensors
| 3 | +Getting Started with Nested Tensors
4 | 4 | ===============================================================
5 | 5 |
6 | 6 | Nested tensors generalize the shape of regular dense tensors, allowing for representation
20 | 20 | for operating on sequential data of varying lengths with a real-world example. In particular,
21 | 21 | they are invaluable for building transformers that can efficiently operate on ragged sequential
22 | 22 | inputs. Below, we present an implementation of multi-head attention using nested tensors that,
23 | | -combined usage of torch.compile, out-performs operating naively on tensors with padding.
| 23 | +when combined with ``torch.compile``, outperforms operating naively on tensors with padding.
24 | 24 |
25 | 25 | Nested tensors are currently a prototype feature and are subject to change.
26 | 26 | """
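As a quick illustration of the ragged structure described in this docstring, here is a minimal sketch (not part of the diff) that packs two sequences of different lengths into one nested tensor and recovers the constituents. It assumes a recent PyTorch build where the prototype ``torch.nested`` API is available; the variable names are illustrative only.

```python
import torch

# Minimal sketch: two "sentences" of different lengths, same embedding dim.
sentence_a = torch.randn(2, 4)   # 2 tokens, embedding dim 4
sentence_b = torch.randn(3, 4)   # 3 tokens, embedding dim 4

# Pack them into a single nested tensor; no padding is introduced.
nt = torch.nested.nested_tensor([sentence_a, sentence_b])

print(nt.is_nested)      # True
# The token dimension is ragged; unbind() recovers the original constituents.
for t in nt.unbind():
    print(t.shape)       # torch.Size([2, 4]) then torch.Size([3, 4])
```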
158 | 158 | # Further, not all operations have the same semantics when applied to padded data.
159 | 159 | # For matrix multiplications, in order to ignore the padded entries, one needs to pad
160 | 160 | # with 0, while for softmax one has to pad with -inf to ignore specific entries.
161 | | -# The ideal that nested tensor seeks to achieve is the ability to operate on ragged data
162 | | -# using the standard PyTorch tensor UX, avoiding inefficient and complicated
163 | | -# padding + masking.
| 161 | +# The primary objective of nested tensor is to facilitate operations on ragged
| 162 | +# data using the standard PyTorch tensor UX, thereby eliminating the need
| 163 | +# for inefficient and complex padding and masking.
164 | 164 | padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
165 | 165 | [3.0, 4.0, 5.0]])
166 | 166 | print(F.softmax(padded_sentences_for_softmax, -1))
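For contrast with the ``-inf``-padded softmax above, here is a minimal sketch (not part of the diff) of the same row-wise softmax on a nested tensor, where no padding value has to be chosen at all. The values mirror the padded example and the variable name is illustrative only.

```python
import torch
import torch.nn.functional as F

# Sketch: the same two variable-length rows, stored without padding.
nested_sentences = torch.nested.nested_tensor([
    torch.tensor([1.0, 2.0]),        # length-2 row
    torch.tensor([3.0, 4.0, 5.0]),   # length-3 row
])

# Softmax over the last dimension; there are no padded entries to mask out.
print(F.softmax(nested_sentences, -1))
```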
@@ -355,9 +355,17 @@ def benchmark(func, *args, **kwargs):
355 | 355 | print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")
356 | 356 |
357 | 357 | ######################################################################
358 | | -# Note that without torch.compile, the overhead of the python subclass nested tensor
| 358 | +# Note that without ``torch.compile``, the overhead of the python subclass nested tensor
359 | 359 | # can make it slower than the equivalent computation on padded tensors. However, once
360 | | -# torch.compile is enabled, operating on nested tensors gives a multiple x speedup.
| 360 | +# ``torch.compile`` is enabled, operating on nested tensors gives a multiple-x speedup.
361 | 361 | # Avoiding wasted computation on padding becomes only more valuable as the percentage
362 | 362 | # of padding in the batch increases.
363 | 363 | print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")
| 364 | +
| 365 | +######################################################################
| 366 | +# Conclusion
| 367 | +# ----------
| 368 | +# In this tutorial, we have learned how to perform basic operations with nested tensors and
| 369 | +# how to implement multi-head attention for transformers in a way that avoids computation on padding.
| 370 | +# For more information, check out the docs for the
| 371 | +# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ namespace.
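To make the eager-versus-compiled comparison above concrete without reproducing the tutorial's full multi-head attention benchmark, here is a rough, self-contained sketch. It assumes a recent PyTorch where ``torch.compile`` supports the jagged nested layout; the function and variable names are illustrative, not from the tutorial.

```python
import timeit

import torch
import torch.nn.functional as F

def row_softmax(x):
    # Stand-in for a heavier computation such as multi-head attention.
    return F.softmax(x, -1)

compiled_row_softmax = torch.compile(row_softmax)

# A batch of ragged sequences packed into a jagged-layout nested tensor.
nt = torch.nested.nested_tensor(
    [torch.randn(n, 64) for n in (5, 17, 32)],
    layout=torch.jagged,
)

# Warm up once so compilation time is excluded, then time both variants.
compiled_row_softmax(nt)
eager_time = timeit.timeit(lambda: row_softmax(nt), number=100)
compiled_time = timeit.timeit(lambda: compiled_row_softmax(nt), number=100)
print(f"eager: {eager_time:.4f}s  compiled: {compiled_time:.4f}s")
```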