
Commit c879422

Editorial update
1 parent d458818 commit c879422

File tree

1 file changed (+15 / -7 lines)


prototype_source/nestedtensor.py

Lines changed: 15 additions & 7 deletions
@@ -1,6 +1,6 @@
 """
 
-Nested Tensors
+Getting Started with Nested Tensors
 ===============================================================
 
 Nested tensors generalize the shape of regular dense tensors, allowing for representation
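
A minimal sketch of the ragged-data idea the retitled tutorial introduces: packing sequences of different lengths into a single nested tensor with no padding. The shapes and the jagged layout below are illustrative assumptions, not lines from the tutorial file.

    import torch

    # Two sequences of different lengths; a dense tensor would force padding here.
    sentences = [torch.randn(3, 8), torch.randn(5, 8)]

    # Illustrative: pack them into one nested tensor (jagged layout assumed available).
    nt = torch.nested.nested_tensor(sentences, layout=torch.jagged)

    for t in nt.unbind():   # each constituent keeps its own length
        print(t.shape)      # torch.Size([3, 8]), then torch.Size([5, 8])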
@@ -20,7 +20,7 @@
 for operating on sequential data of varying lengths with a real-world example. In particular,
 they are invaluable for building transformers that can efficiently operate on ragged sequential
 inputs. Below, we present an implementation of multi-head attention using nested tensors that,
-combined usage of torch.compile, out-performs operating naively on tensors with padding.
+combined usage of ``torch.compile``, out-performs operating naively on tensors with padding.
 
 Nested tensors are currently a prototype feature and are subject to change.
 """
@@ -158,9 +158,9 @@
 # Further, not all operations have the same semnatics when applied to padded data.
 # For matrix multiplications in order to ignore the padded entries, one needs to pad
 # with 0 while for softmax one has to pad with -inf to ignore specific entries.
-# The ideal that nested tensor seeks to achieve is the ability to operate on ragged data
-# using the standard PyTorch tensor UX, avoiding inefficient and complicated
-# padding + masking.
+# The primary objective of nested tensor is to facilitate operations on ragged
+# data using the standard PyTorch tensor UX, thereby eliminating the need
+# for inefficient and complex padding and masking.
 padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                              [3.0, 4.0, 5.0]])
 print(F.softmax(padded_sentences_for_softmax, -1))
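
To make the -inf contrast in this hunk concrete, a hedged sketch of the nested-tensor counterpart follows: softmax applied directly to ragged rows, so no sentinel values are needed. The variable name is illustrative rather than copied from the tutorial.

    import torch
    import torch.nn.functional as F

    # The same two "sentences" packed without padding.
    nested_sentences = torch.nested.nested_tensor(
        [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
    )

    # Each ragged row is normalized over its true length; no -inf entries required.
    print(F.softmax(nested_sentences, -1))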
@@ -355,9 +355,17 @@ def benchmark(func, *args, **kwargs):
 print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")
 
 ######################################################################
-# Note that without torch.compile, the overhead of the python subclass nested tensor
+# Note that without ``torch.compile``, the overhead of the python subclass nested tensor
 # can make it slower than the equivalent computation on padded tensors. However, once
-# torch.compile is enabled, operating on nested tensors gives a multiple x speedup.
+# ``torch.compile`` is enabled, operating on nested tensors gives a multiple x speedup.
 # Avoiding wasted computation on padding becomes only more valuable as the percentage
 # of padding in the batch increases.
 print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")
+
+######################################################################
+# Conclusion
+# ----------
+# In this tutorial, we have learned how to perform basic operations with nested tensors and
+# how implement multi-head attention for transformers in a way that avoids computation on padding.
+# For more information, check out the docs for the
+# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ namespace.
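
As a hedged illustration of the torch.compile point in this hunk, separate from the tutorial's benchmark code: compiling a small function that operates on a jagged nested tensor. The function, shapes, and layout are assumptions, and the actual speedup varies with PyTorch version and hardware.

    import torch
    import torch.nn.functional as F

    def scale_and_activate(nt):
        # Pointwise work touches only real elements; there is no padding to waste compute on.
        return F.relu(nt * 2.0)

    compiled_fn = torch.compile(scale_and_activate)

    nt = torch.nested.nested_tensor(
        [torch.randn(3, 8), torch.randn(5, 8)], layout=torch.jagged
    )
    out = compiled_fn(nt)
    print(out.unbind()[0].shape)   # torch.Size([3, 8]); the ragged structure is preserved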

0 commit comments
