
Commit 3323490

improve patch for attention
1 parent 175a800 commit 3323490

File tree: 2 files changed, +61 -0 lines changed


_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 53 additions & 0 deletions
@@ -5,6 +5,8 @@
 import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


 class TestPatchPatchTransformers(ExtTestCase):
@@ -121,6 +123,57 @@ def test_causal_mask_in_scaled_dot_product_attention(self):
         attn_causal_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         self.assertEqual(attn_causal_bias.min().item(), -float("inf"))

+    # @ignore_warnings(UserWarning)
+    def test_causal_mask_in_scaled_dot_product_attention_export(self):
+        sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+        patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
+        kwargs = {
+            "module": None,
+            "query": torch.rand((1, 2, 1, 96), dtype=torch.float32),
+            "key": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "value": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "attention_mask": None,
+            "attention_dropout": 0,
+            "scaling": 0.10206207261596575,
+            "is_causal": True,
+        }
+        expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        self.assertEqualArray(expected, got)
+
+        class Model(torch.nn.Module):
+            def forward(self, query, key, value):
+                kwargs = {
+                    "module": None,
+                    "query": query,
+                    "key": key,
+                    "value": value,
+                    "attention_mask": None,
+                    "attention_dropout": 0,
+                    "scaling": 0.10206207261596575,
+                    "is_causal": True,
+                }
+                return patched_sdpa_attention_forward(**kwargs)[0]
+
+        query, key, value = kwargs["query"], kwargs["key"], kwargs["value"]
+        model = Model()
+        got = model(query, key, value)
+        self.assertEqualArray(expected, got)
+
+        # static export
+        ep = torch.export.export(model, (query, key, value))
+        got = ep.module()(query, key, value)
+        self.assertEqualArray(expected, got)
+
+        # dynamic
+        ds = ({0: "batch", 2: "seq1"}, {0: "batch", 2: "seq2"}, {0: "batch", 2: "seq2"})
+        fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key, value), ds)
+        print("****", fake_inputs)
+        epd = torch.export.export(model, fake_inputs)  # , dynamic_shapes=use_dyn_not_str(ds))
+        print(epd)
+        got = epd.module()(query, key, value)
+        self.assertEqualArray(expected, got)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
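
Note (not part of the commit): the dynamic branch of the new test exports from fake tensors built by make_fake_with_dynamic_dimensions, with the explicit dynamic_shapes argument left commented out. As a rough, hedged sketch of the same idea using only public torch.export APIs, one could mark the batch and sequence dimensions with torch.export.Dim on a stand-in module; the module and shapes below are illustrative only, not the patched attention:

import torch


class ScaledDotProduct(torch.nn.Module):
    # Stand-in for the patched attention: a plain matmul over query/key so the
    # sketch runs without transformers or onnx_diagnostic installed.
    def forward(self, query, key):
        return query @ key.transpose(-1, -2)


batch = torch.export.Dim("batch")
seq1 = torch.export.Dim("seq1")
seq2 = torch.export.Dim("seq2")
ep = torch.export.export(
    ScaledDotProduct(),
    (torch.rand(2, 2, 3, 96), torch.rand(2, 2, 4, 96)),
    dynamic_shapes=({0: batch, 2: seq1}, {0: batch, 2: seq2}),
)
print(ep.graph)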

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 8 additions & 0 deletions
@@ -1374,6 +1374,14 @@ def patched_sdpa_attention_forward(
         attention_mask is None or attention_mask.shape[3] == key.shape[2],
         "Attention mask shape incompatible with key shape.",
     )
+    torch._check(
+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
+        lambda: f"broadcast issue query (1): {query.shape}, key: {key.shape}, value: {value.shape}",
+    )
+    torch._check(
+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
+        lambda: f"broadcast issue query (2): {query.shape}, key: {key.shape}, value: {value.shape}",
+    )
     if is_causal:
         attn_output = torch.cond(
             query.shape[2] > 1,  # distinction between prefill and decoding steps