diff --git a/prototype_source/nestedtensor.py b/prototype_source/nestedtensor.py
index 4462055b0c5..ecf099c1e02 100644
--- a/prototype_source/nestedtensor.py
+++ b/prototype_source/nestedtensor.py
@@ -1,36 +1,47 @@
"""
-NestedTensors
+Getting Started with Nested Tensors
===============================================================
-NestedTensors are similar to regular tensors, except for their shape:
+Nested tensors generalize the shape of regular dense tensors, allowing for representation
+of ragged-sized data.
-* for a regular tensor, each dimension has a size
+* for a regular tensor, each dimension is regular and has a size
-* for a nestedtensor, not all dimensions have regular sizes; some of them are jagged
+* for a nested tensor, not all dimensions have regular sizes; some of them are ragged
-Nestedtensors are a natural solution for representing sequential data within various domains:
+Nested tensors are a natural solution for representing sequential data within various domains:
-* in NLP, sentences can have variable lengths, so a batch of sentences forms a nestedtensor
+* in NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor
-* in CV, images can have variable shapes, so a batch of images forms a nestedtensor
+* in CV, images can have variable shapes, so a batch of images forms a nested tensor
-In this tutorial, we will demonstrate basic usage of nestedtensors and motivate their usefulness
-for operating on sequential data of varying lengths with a real-world example.
+In this tutorial, we will demonstrate basic usage of nested tensors and motivate their usefulness
+for operating on sequential data of varying lengths with a real-world example. In particular,
+they are invaluable for building transformers that can efficiently operate on ragged sequential
+inputs. Below, we present an implementation of multi-head attention using nested tensors that,
+combined with ``torch.compile``, outperforms operating naively on tensors with padding.
-NestedTensor are currently a prototype feature and are subject to change.
+Nested tensors are currently a prototype feature and are subject to change.
"""
+import numpy as np
+import timeit
import torch
import torch.nn.functional as F
+from torch import nn
+
+torch.manual_seed(1)
+np.random.seed(1)
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
######################################################################
-# NestedTensor Initialization
+# Nested tensor initialization
# ----------------------------
#
-# From the Python frontend, a nestedtensor can be created from a list of tensors.
+# From the Python frontend, a nested tensor can be created from a list of tensors.
-# We denote nt[i] as the ith tensor component of a nestedtensor.
+# We denote nt[i] as the ith tensor component of a nested tensor.
nt = torch.nested.nested_tensor([torch.arange(12).reshape(
    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
@@ -141,12 +152,15 @@ print(f"{nested_sentences=}")
######################################################################
-# This techinque of padding a batch of data to its max length is not optimal.
+# This technique of padding a batch of data to its max length is not optimal.
# The padded data is not needed for computation and wastes memory by allocating
# larger tensors than necessary.
-# Further, not all operations have the same semnatics when applied to padded data.
+# Further, not all operations have the same semantics when applied to padded data.
# For matrix multiplications in order to ignore the padded entries, one needs to pad
# with 0 while for softmax one has to pad with -inf to ignore specific entries.
+# The primary objective of nested tensor is to facilitate operations on ragged
+# data using the standard PyTorch tensor UX, thereby eliminating the need
+# for inefficient and complex padding and masking.
padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
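+# By contrast, softmax on a nested tensor normalizes each ragged component over
+# its real entries only, so no -inf padding trick is needed. (A minimal sketch
+# of the point above; nested tensor operator coverage is still evolving.)
+nested_sentences_for_softmax = torch.nested.nested_tensor(
+    [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])])
+print(F.softmax(nested_sentences_for_softmax, -1))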
@@ -155,199 +169,83 @@
######################################################################
# Let us take a look at a practical example: the multi-head attention component
# utilized in `Transformers <https://arxiv.org/abs/1706.03762>`__.
-# The nestedtensor version is straightforward.
-import math
-
-def mha_nested(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
-               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
-               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
-               dropout_p: float = 0.0) -> torch.Tensor:
-    """Compute multi-head attention with nested tensors.
-    Args:
-        query (torch.Tensor): query of shape (N, L_t, E_q)
-        key (torch.Tensor): key of shape (N, L_s, E_k)
-        value (torch.Tensor): value of shape (N, L_s, E_v)
-        nheads (int): number of heads in multi-head attention
-        W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
-        W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
-        W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
-        W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
-        b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Default: None. Defaults to None.
-        b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Default: None. Defaults to None.
-        b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Default: None. Defaults to None.
-        b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Default: None. Defaults to None.
-        dropout_p (float, optional): Dropout probability. Defaults to 0.0.
-
-    Where:
-        N is the batch size
-        L_t is the target sequence length (jagged)
-        L_s is the source sequence length (jagged)
-        E_q is the embedding size for query
-        E_k is the embedding size for key
-        E_v is the embedding size for value
-        E_total is the embedding size for all heads combined
-        E_out is the output embedding size
-    Returns:
-        torch.Tensor: Output of shape (N, L_t, E_out)
+# We can implement this in such a way that it can operate on either padded
+# or nested tensors.
+class MultiHeadAttention(nn.Module):
    """
-
-    N = query.size(0)
-    E_total = W_q.size(0)
-    assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
-    E_head = E_total // nheads
-
-    # apply input projection
-    # (N, L_t, E_q) -> (N, L_t, E_total)
-    query = F.linear(query, W_q, b_q)
-    # (N, L_s, E_k) -> (N, L_s, E_total)
-    key = F.linear(key, W_k, b_k)
-    # (N, L_s, E_v) -> (N, L_s, E_total)
-    value = F.linear(value, W_v, b_v)
-
-    # reshape query, key, value to separate by head
-    # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
-    query = query.reshape(N, -1, nheads, E_head).transpose(1, 2)
-    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
-    key = key.reshape(N, -1, nheads, E_head).transpose(1, 2)
-    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
-    value = value.reshape(N, -1, nheads, E_head).transpose(1, 2)
-
-    # query matmul key^T
-    # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)
-    keyT = key.transpose(-1, -2)
-    attn_weights = torch.matmul(query, keyT)
-
-    # scale down
-    attn_weights = attn_weights * (1.0 / math.sqrt(E_head))
-
-    # softmax
-    attn_weights = F.softmax(attn_weights, dim=-1)
-
-    # dropout
-    if dropout_p > 0.0:
-        attn_weights = F.dropout(attn_weights, p=dropout_p)
-
-    # attention_weights matmul value
-    # (N, nheads, L_t, L_s) x (N, nheads, L_s, E_head) -> (N, nheads, L_t, E_head)
-    attn_output = torch.matmul(attn_weights, value)
-
-    # merge heads
-    # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
-    attn_output = attn_output.transpose(1, 2).reshape(N, -1, E_total)
-
-    # apply output projection
-    # (N, L_t, E_total) -> (N, L_t, E_out)
-    attn_output = F.linear(attn_output, W_out, b_out)
-
-    return attn_output
-
-######################################################################
-# The 0-padded tensor version additionally requires masks
-# for more complicated treatments at padded entries.
-def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
-               attn_mask_q: torch.Tensor, attn_mask_kv: torch.Tensor,
-               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
-               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
-               dropout_p: float = 0.0) -> torch.Tensor:
-    """Compute multi-head attention for padded out dense tensors.
+    Computes multi-head attention. Supports nested or padded tensors.
    Args:
-        query (torch.Tensor): query of shape (N, L_t, E_q)
-        key (torch.Tensor): key of shape (N, L_s, E_k)
-        value (torch.Tensor): value of shape (N, L_s, E_v)
-        nheads (int): number of heads in multi-head attention
-        attn_mask_q (torch.Tensor): boolean mask indicating locations that should not take part in attention for query, shape (N, L_t)
-        attn_mask_kv (torch.Tensor): boolean mask indicating locations that should not take part in attention for key and value, shape (N, L_s)
-        W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
-        W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
-        W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
-        W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
-        b_q (torch.Tensor, optional): Bias for query input projection of shape E_total.. Defaults to None.
-        b_k (torch.Tensor, optional): Bias for key input projection of shape E_total.. Defaults to None.
-        b_v (torch.Tensor, optional): Bias for value input projection of shape E_total.. Defaults to None.
-        b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
-        dropout_p (float, optional): Dropout probability. Defaults to 0.0.
-
-    Where:
-        N is the batch size
-        L_t is the target sequence length (padded)
-        L_s is the source sequence length (padded)
-        E_q is the embedding size for query
-        E_k is the embedding size for key
-        E_v is the embedding size for value
-        E_total is the embedding size for all heads combined
-        E_out is the output embedding size
-    Returns:
-        torch.Tensor: Output of shape (N, L_t, E_out)
+        E_q (int): Size of embedding dim for query
+        E_k (int): Size of embedding dim for key
+        E_v (int): Size of embedding dim for value
+        E_total (int): Total embedding dim of combined heads post input projection. Each head
+            has dim E_total // nheads
+        nheads (int): Number of heads
+        dropout_p (float, optional): Dropout probability. Default: 0.0
    """
-    N = query.size(0)
-    L_t = query.size(1)
-    L_s = key.size(1)
-    E_total = W_q.size(0)
-    assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
-    assert L_t == L_s, "This implementation assumes equal query and key sequence lengths"
-    E_head = E_total // nheads
-
-    # apply input projection
-    # (N, L_t, E_q) -> (N, L_t, E_total)
-    query = F.linear(query, W_q, b_q)
-    # (N, L_s, E_k) -> (N, L_s, E_total)
-    key = F.linear(key, W_k, b_k)
-    # (N, L_s, E_v) -> (N, L_s, E_total)
-    value = F.linear(value, W_v, b_v)
-
-    # reshape query, key, value to separate by head
-    # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
-    query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
-    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
-    key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
-    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
-    value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
-
-    # query bmm key^T
-    # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
-    keyT = key.transpose(-1, -2)
-    attn_weights = torch.bmm(query, keyT)
-
-    # scale down
-    attn_weights = attn_weights * (1.0 / math.sqrt(E_head))
-
-    # Have to manipulate masks in order to apply them to the attention weights
-    key_padding_mask = attn_mask_q.view(N, 1, 1, L_t).expand(-1, nheads, -1, -1).reshape(N*nheads, 1, L_t).to(device=device)
-    attn_mask = torch.zeros(key_padding_mask.shape, device=device, dtype=torch.float32)
-    attn_mask = attn_mask.masked_fill_(key_padding_mask, float("-inf"))
-
-    # Zero out the attention weights where the mask is True by adding -inf prior to softmax
-    attn_weights.add_(attn_mask)
-
-    # softmax
-    attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)
-
-    # dropout
-    if dropout_p > 0.0:
-        attn_weights = F.dropout(attn_weights, p=dropout_p)
-
-    # attention_weights bmm value
-    # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)
-    attn_output = attn_weights.bmm(value)
-
-    # merge heads
-    # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
-    attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)
-
-    # apply output projection
-    # (N, L_t, E_total) -> (N, L_t, E_out)
-    attn_output = F.linear(attn_output, W_out, b_out)
-
-    # padding-specific step: remove output projection bias from padded entries
-    attn_output[attn_mask_q, :] = 0.0
-
-    return attn_output
+    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
+                 nheads: int, dropout_p: float = 0.0):
+        super().__init__()
+        self.nheads = nheads
+        self.dropout_p = dropout_p
+        self.query_proj = nn.Linear(E_q, E_total)
+        self.key_proj = nn.Linear(E_k, E_total)
+        self.value_proj = nn.Linear(E_v, E_total)
+        E_out = E_q
+        self.out_proj = nn.Linear(E_total, E_out)
+        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
+        self.E_head = E_total // nheads
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass; runs the following process:
+            1. Apply input projection
+            2. Split heads and prepare for SDPA
+            3. Run SDPA
+            4. Apply output projection
+
+        Args:
+            query (torch.Tensor): query of shape (N, L_t, E_q)
+            key (torch.Tensor): key of shape (N, L_s, E_k)
+            value (torch.Tensor): value of shape (N, L_s, E_v)
+
+        Returns:
+            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
+        """
+        # Step 1. Apply input projection
+        # TODO: demonstrate packed projection
+        query = self.query_proj(query)
+        key = self.key_proj(key)
+        value = self.value_proj(value)
+
+        # Step 2. Split heads and prepare for SDPA
+        # reshape query, key, value to separate by head
+        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
+        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
+        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
+        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
+        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
+        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
+
+        # Step 3. Run SDPA
+        # (N, nheads, L_t, E_head)
+        attn_output = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=self.dropout_p, is_causal=True)
+        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
+        attn_output = attn_output.transpose(1, 2).flatten(-2)
+
+        # Step 4. Apply output projection
+        # (N, L_t, E_total) -> (N, L_t, E_out)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
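+######################################################################
+# As a quick smoke test (purely illustrative, with toy sizes; not part of the
+# benchmark below), the module works on an ordinary dense batch out of the box:
+_toy_mha = MultiHeadAttention(E_q=8, E_k=8, E_v=8, E_total=8, nheads=2).to(device=device)
+_toy_query = torch.randn(2, 5, 8, device=device)
+print(_toy_mha(_toy_query, _toy_query, _toy_query).shape)  # -> (2, 5, 8)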
######################################################################
# set hyperparameters following `the Transformer paper <https://arxiv.org/abs/1706.03762>`__
N = 512
-E_q, E_k, E_v, E_total, E_out = 512, 512, 512, 512, 512
+E_q, E_k, E_v, E_total = 512, 512, 512, 512
+E_out = E_q
nheads = 8
######################################################################
@@ -356,9 +254,7 @@ def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nhea
######################################################################
# Let us generate some realistic fake data from Zipf's law.
-import numpy as np
-
-def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
+def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
@@ -368,124 +264,108 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
-    return sentence_lengths
+    return torch.tensor(sentence_lengths)
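+######################################################################
+# For illustration, sample a small batch of lengths (an aside, not used below):
+# the Zipf distribution is heavily skewed toward short sentences, which is what
+# makes padding everything to the max length so wasteful.
+print(f"{zipf_sentence_lengths(alpha=1.2, batch_size=10)=}")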
-alpha = 1.2
+######################################################################
+# Create nested tensor batch inputs
+def gen_batch(N, E_q, E_k, E_v, device):
+    # generate semi-realistic data using Zipf distribution for sentence lengths
+    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)
-sentence_lengths = zipf_sentence_lengths(alpha, N)
-L_t = np.max(sentence_lengths)
-L_s = L_t
+    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
+    # dimension and works with torch.compile. The resulting batched tensor has shape
+    # (B, S*, D), where B = batch size, S* = ragged sequence length, and D = embedding dimension.
+    query = torch.nested.nested_tensor([
+        torch.randn(l.item(), E_q, device=device)
+        for l in sentence_lengths
+    ], layout=torch.jagged)
-######################################################################
-# create inputs
-
-# create parameters
-W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
-W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
-W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
-W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)
-
-# create nested input
-queries = []
-keys = []
-values = []
-for i in range(N):
-    l = sentence_lengths[i]
-    s = l
-    queries.append(torch.randn((l, E_q), device=device))
-    keys   .append(torch.randn((s, E_k), device=device))
-    values .append(torch.randn((s, E_v), device=device))
-query = torch.nested.nested_tensor(queries)
-key = torch.nested.nested_tensor(keys)
-value = torch.nested.nested_tensor(values)
-
-# pad input
-padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
-padded_key = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))
-padded_value = torch.nested.to_padded_tensor(value, 0.0, (N, L_s, E_v))
-
-# create attention masks
-attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
-attn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)
-
-# We need to mask out the padding entries in the attention weights.
-for i, entry_length in enumerate(sentence_lengths):
-    attn_mask_q[i, entry_length:] = True
-    attn_mask_kv[i, entry_length:] = True
+    key = torch.nested.nested_tensor([
+        torch.randn(s.item(), E_k, device=device)
+        for s in sentence_lengths
+    ], layout=torch.jagged)
+
+    value = torch.nested.nested_tensor([
+        torch.randn(s.item(), E_v, device=device)
+        for s in sentence_lengths
+    ], layout=torch.jagged)
+
+    return query, key, value, sentence_lengths
+
+query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
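+######################################################################
+# Peek at the ragged structure (illustrative; nested tensor APIs are prototype
+# and may change). Each batch item keeps its true length; no padding is stored.
+print(f"{query.is_nested=}")
+print("lengths of the first 5 components:", [q.shape[0] for q in query.unbind()[:5]])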
######################################################################
-# check correctness and performance
+# Generate padded forms of query, key, value for comparison
+def jagged_to_padded(jt, padding_val):
+    # TODO: do jagged -> padded directly when this is supported
+    return torch.nested.to_padded_tensor(
+        torch.nested.nested_tensor(list(jt.unbind())),
+        padding_val)
-import timeit
+padded_query, padded_key, padded_value = (
+    jagged_to_padded(t, 0.0) for t in (query, key, value)
+)
-t0 = timeit.default_timer()
-out_nested = mha_nested(
-    query, key, value, nheads,
-    W_q, W_k, W_v, W_out,
-    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
-    dropout_p=dropout_p)
-
-t1 = timeit.default_timer()
-out_padded = mha_padded(
-    padded_query, padded_key, padded_value, nheads,
-    attn_mask_q, attn_mask_kv,
-    W_q, W_k, W_v, W_out,
-    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
-    dropout_p=dropout_p)
-t2 = timeit.default_timer()
-
-print("nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
-print("nestedtensor multi-head attention takes", t1 - t0, "seconds")
-print("padded tensor multi-head attention takes", t2 - t1, "seconds")
+######################################################################
+# Construct the model
+mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)
######################################################################
-# Although the nestedtensor version avoids wasted computation on padding, it is not faster
-# then the equivalent padded tensor version. This is because the nestedtensor version
-# has implemented a few of the kernels, like softmax, in a non optimal way.
-#
-# There are plans to implement performance critical operations using the new Pytorch 2.0 stack
-# For now, some performant kernels are provided for specific use cases, e.g.
-# self-attention evaluation by multi-head attention formula.
+# Check correctness and performance
+def benchmark(func, *args, **kwargs):
+    # guard the synchronization so the timing code also runs on CPU-only machines
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    begin = timeit.default_timer()
+    output = func(*args, **kwargs)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    end = timeit.default_timer()
+    return output, (end - begin)
+
+output_nested, time_nested = benchmark(mha, query, key, value)
+output_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)
+
+# padding-specific step: remove output projection bias from padded entries for fair comparison
+for i, entry_length in enumerate(sentence_lengths):
+    output_padded[i, entry_length:] = 0.0
+
+print("=== without torch.compile ===")
+print("nested and padded calculations differ by", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())
+print("nested tensor multi-head attention takes", time_nested, "seconds")
+print("padded tensor multi-head attention takes", time_padded, "seconds")
+
+# warm up compile first...
+compiled_mha = torch.compile(mha)
+compiled_mha(query, key, value)
+# ...now benchmark
+compiled_output_nested, compiled_time_nested = benchmark(
+    compiled_mha, query, key, value)
+
+# warm up compile first...
+compiled_mha(padded_query, padded_key, padded_value)
+# ...now benchmark
+compiled_output_padded, compiled_time_padded = benchmark(
+    compiled_mha, padded_query, padded_key, padded_value)
+
+# padding-specific step: remove output projection bias from padded entries for fair comparison
+for i, entry_length in enumerate(sentence_lengths):
+    compiled_output_padded[i, entry_length:] = 0.0
-# embeddings are assumed to be the same
-E = E_total
-mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
-mha_lib.eval()
+print("=== with torch.compile ===")
+print("nested and padded calculations differ by", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())
+print("nested tensor multi-head attention takes", compiled_time_nested, "seconds")
+print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")
######################################################################
-# extract parameters for correctness check
-mha_lib.in_proj_weight.requires_grad_(False)
-mha_lib.in_proj_bias.requires_grad_(False)
-mha_lib.out_proj.weight.requires_grad_(False)
-mha_lib.out_proj.bias.requires_grad_(False)
-W_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]
-W_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]
-W_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]
-W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias
+# Note that without ``torch.compile``, the overhead of the Python subclass nested tensor
+# can make it slower than the equivalent computation on padded tensors. However, once
+# ``torch.compile`` is enabled, operating on nested tensors gives a several-fold speedup.
+# Avoiding wasted computation on padding becomes even more valuable as the percentage
+# of padding in the batch increases.
+print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")
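+######################################################################
+# How much computation padding wastes can be estimated directly from the sampled
+# lengths (an illustrative calculation for this particular batch):
+total_tokens = sentence_lengths.sum().item()
+padded_total = sentence_lengths.numel() * sentence_lengths.max().item()
+print(f"fraction of batch that is padding: {1 - total_tokens / padded_total:.2%}")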
######################################################################
-# If we set need_weights to False this will enable the fast path in the library.
-# Under the hood this will call _scaled_dot_product_attention. If your tensors
-# are on CUDA, than a fused, efficient attention kernel will be used. For
-# more detailed performance characteristics look at the benchmark in
-# pytorch/benchmarks/transformer/sdp.py
-
-with torch.inference_mode():
-    t0 = timeit.default_timer()
-    out_lib, out_lib_weights = mha_lib(query, query, query, need_weights=False)
-
-    t1 = timeit.default_timer()
-    padded_out = mha_padded(
-        padded_query, padded_query, padded_query, nheads,
-        attn_mask_q, attn_mask_q,
-        W_q, W_k, W_v, W_out,
-        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
-        dropout_p=dropout_p)
-    t2 = timeit.default_timer()
-
-nested_time = t1 - t0
-padded_time = t2 - t1
-print("Nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_lib, 0.0) - padded_out).abs().max().item())
-print("Nested library multi-head attention takes", nested_time, "seconds")
-print("Padded tensor multi-head attention takes", padded_time, "seconds")
-print(f"Nested Speedup: {padded_time / nested_time:.3f}")
+# Conclusion
+# ----------
+# In this tutorial, we have learned how to perform basic operations with nested tensors and
+# how to implement multi-head attention for transformers in a way that avoids computation on padding.
+# For more information, check out the docs for the
+# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ namespace.