
Commit 9b86d3e
fix patch
1 parent 5a4c01c

File tree: 3 files changed, +83 / -6 lines changed

_doc/recipes/plot_export_dim1.py
32 additions, 4 deletions
@@ -13,6 +13,8 @@
 
 import torch
 from onnx_diagnostic import doc
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches import torch_export_patches
 
 
 class Model(torch.nn.Module):
@@ -29,21 +31,28 @@ def forward(self, x, y, z):
 DYN = torch.export.Dim.DYNAMIC
 ds = {0: DYN, 1: DYN}
 
+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
 ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
-print(ep.graph)
+print(ep)
 
 # %%
 # Same model, a dynamic dimension = 1
 # +++++++++++++++++++++++++++++++++++
 
+
 z = z[:1]
 
 DYN = torch.export.Dim.DYNAMIC
 ds = {0: DYN, 1: DYN}
 
+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
 try:
     ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
-    print(ep.graph)
+    print(ep)
 except Exception as e:
     print("ERROR", e)
 

@@ -54,14 +63,33 @@ def forward(self, x, y, z):
 # Same model, a dynamic dimension = 1 and backed_size_oblivious=True
 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 
+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
 try:
     with torch.fx.experimental._config.patch(backed_size_oblivious=True):
         ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
-        print(ep.graph)
+        print(ep)
 except RuntimeError as e:
     print("ERROR", e)
 
+
+# %%
+# Final try with patches...
+# ++++++++++++++++++++++++
+
+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
+with torch_export_patches(patch_torch=1):
+    try:
+        ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
+        print(ep)
+    except RuntimeError as e:
+        print("ERROR", e)
+
 # %%
-# It worked.
+# It is difficult to find the right option. It is possible on a simple model
+# but sometimes impossible on a bigger model mixing different shapes.
 
 doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")
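For readers who want to reproduce the failure mode outside this recipe, here is a minimal, hedged sketch (the toy model and shapes below are illustrative and not part of the commit): export with Dim.DYNAMIC goes through when the traced dimension is greater than 1, a dimension equal to 1 may be specialized or rejected, and backed_size_oblivious=True is one way to ask the exporter to keep it symbolic.

import torch


class TinyModel(torch.nn.Module):  # illustrative toy, not the recipe's Model
    def forward(self, x):
        return torch.sin(x) + 1


model = TinyModel()
DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

# A sample dimension greater than 1 exports cleanly.
ep = torch.export.export(model, (torch.rand((3, 4)),), dynamic_shapes=(ds,))
print(ep)

# With a sample dimension equal to 1, the exporter may specialize it to 1
# or raise; backed_size_oblivious=True asks it to keep the size symbolic.
try:
    with torch.fx.experimental._config.patch(backed_size_oblivious=True):
        ep1 = torch.export.export(model, (torch.rand((1, 4)),), dynamic_shapes=(ds,))
    print(ep1)
except RuntimeError as e:
    print("ERROR", e)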

_unittests/ut_torch_export_patches/test_patch_transformers.py
50 additions, 1 deletion
@@ -124,7 +124,7 @@ def test_causal_mask_in_scaled_dot_product_attention(self):
         self.assertEqual(attn_causal_bias.min().item(), -float("inf"))
 
     @ignore_warnings(UserWarning)
-    def test_causal_mask_in_scaled_dot_product_attention_export(self):
+    def test_sdpa_attention_forward_export_is_causal(self):
         sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
         patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
         kwargs = {
@@ -172,6 +172,55 @@ def forward(self, query, key, value):
         got = epd.module()(query, key, value)
         self.assertEqualArray(expected, got)
 
+    @ignore_warnings(UserWarning)
+    def test_sdpa_attention_forward_export_is_causal_none(self):
+        sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+        patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
+        kwargs = {
+            "module": None,
+            "query": torch.rand((1, 2, 1, 96), dtype=torch.float32),
+            "key": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "value": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "attention_mask": None,
+            "attention_dropout": 0,
+            "scaling": 0.10206207261596575,
+            "is_causal": None,
+        }
+        expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        self.assertEqualArray(expected, got)
+
+        class Model(torch.nn.Module):
+            def forward(self, query, key, value):
+                kwargs = {
+                    "module": None,
+                    "query": query,
+                    "key": key,
+                    "value": value,
+                    "attention_mask": None,
+                    "attention_dropout": 0,
+                    "scaling": 0.10206207261596575,
+                    "is_causal": None,
+                }
+                return patched_sdpa_attention_forward(**kwargs)[0]
+
+        query, key, value = kwargs["query"], kwargs["key"], kwargs["value"]
+        model = Model()
+        got = model(query, key, value)
+        self.assertEqualArray(expected, got)
+
+        # static export
+        ep = torch.export.export(model, (query, key, value))
+        got = ep.module()(query, key, value)
+        self.assertEqualArray(expected, got)
+
+        # dynamic
+        ds = ({0: "batch", 2: "seq1"}, {0: "batch", 2: "seq2"}, {0: "batch", 2: "seq2"})
+        fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key, value), ds)
+        epd = torch.export.export(model, fake_inputs, dynamic_shapes=use_dyn_not_str(ds))
+        got = epd.module()(query, key, value)
+        self.assertEqualArray(expected, got)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
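The new test checks that patched_sdpa_attention_forward matches the stock implementation when is_causal is None, and that a module wrapping it still exports both statically and with dynamic shapes. Below is a rough standalone sketch of that export pattern using only public torch APIs rather than the repository's helpers; the wrapper name and the non-degenerate shapes are illustrative assumptions, not code from this commit.

import torch


class SDPAWrapper(torch.nn.Module):  # illustrative stand-in for the test's Model
    def forward(self, query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, scale=0.10206207261596575
        )


# Sizes greater than 1 keep the dynamic export straightforward.
query = torch.rand((2, 2, 3, 96), dtype=torch.float32)
key = torch.rand((2, 2, 4, 96), dtype=torch.float32)
value = torch.rand((2, 2, 4, 96), dtype=torch.float32)

model = SDPAWrapper()
expected = model(query, key, value)

# static export
ep = torch.export.export(model, (query, key, value))
torch.testing.assert_close(expected, ep.module()(query, key, value))

# dynamic export: batch and both sequence lengths stay symbolic
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 2: DYN}, {0: DYN, 2: DYN}, {0: DYN, 2: DYN})
epd = torch.export.export(model, (query, key, value), dynamic_shapes=ds)
torch.testing.assert_close(expected, epd.module()(query, key, value))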

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
1 addition, 1 deletion
@@ -1394,7 +1394,7 @@ def patched_sdpa_attention_forward(
             f"value: {value.shape}"
         ),
     )
-    if not is_causal:
+    if not is_causal or not patch_is_causal:
         return (
             torch.nn.functional.scaled_dot_product_attention(
                 query,
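The one-line change widens the early-exit condition: the patched attention now falls back to the plain scaled_dot_product_attention call not only when is_causal is falsy but also when causal patching is disabled via patch_is_causal. A hedged sketch of that control flow follows; the function below is illustrative and is not the body of patched_sdpa_attention_forward.

import torch


def sdpa_guard_sketch(query, key, value, attention_mask, is_causal, patch_is_causal):
    # Fall back to the stock kernel when the call is not causal, or when the
    # causal-handling patch is turned off.
    if not is_causal or not patch_is_causal:
        # SDPA rejects an explicit mask combined with is_causal=True, so only
        # forward is_causal when no mask is given (an assumption of this sketch).
        return torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            is_causal=bool(is_causal) if attention_mask is None else False,
        )
    # Otherwise take the patched, export-friendly causal path (not shown here).
    raise NotImplementedError("patched causal path omitted in this sketch")


# Minimal usage of the sketch: is_causal=None routes through the fallback branch.
q = torch.rand((2, 2, 3, 8))
k = v = torch.rand((2, 2, 4, 8))
out = sdpa_guard_sketch(q, k, v, attention_mask=None, is_causal=None, patch_is_causal=True)
print(out.shape)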
