Skip to content

Commit 05d7dc2

Browse files
committed
Fix NestedTensor deprecation warning in SDPA tutorial
Add layout=torch.jagged parameter to nested_tensor call to suppress the prototype API warning and use the recommended jagged layout.
1 parent f99e9e8 commit 05d7dc2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

intermediate_source/scaled_dot_product_attention_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def generate_rand_batch(
223223
torch.randn(seq_len, embed_dimension,
224224
dtype=dtype, device=device)
225225
for seq_len in seq_len_list
226-
]
226+
], layout=torch.jagged
227227
),
228228
seq_len_list,
229229
)

0 commit comments

Comments
 (0)