Skip to content

Commit a9c08f9

Browse files
committed
fixes
1 parent 69df2e4 commit a9c08f9

File tree

6 files changed

+14
-24
lines changed

6 files changed

+14
-24
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.10
55
++++++
66

7+
* :pr:`218`: patches used sdpa_mask_recent_torch used from _vmap_for_bhqkv
8+
79
0.7.9
810
+++++
911

_unittests/ut_export/test_jit.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
ignore_warnings,
88
requires_onnxscript,
99
)
10-
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
1110
from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting
1211

1312
try:
@@ -75,19 +74,6 @@ def forward(self, images, position):
7574
y = torch.arange(5, dtype=torch.int64) + 1
7675
expected = model(x, y)
7776

78-
name = self.get_dump_file("test_export_loop_onnxscript.onnx")
79-
torch.onnx.export(
80-
model,
81-
(x, y),
82-
name,
83-
dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
84-
dynamo=False,
85-
)
86-
ref = ExtendedReferenceEvaluator(name)
87-
feeds = dict(images=x.numpy(), position=y.numpy())
88-
got = ref.run(None, feeds)[0]
89-
self.assertEqualArray(expected, got)
90-
9177
DYN = torch.export.Dim.DYNAMIC
9278
ep = torch.export.export(
9379
model,

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
1515
torch_export_patches,
1616
)
17+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1718
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
1819

1920

@@ -305,7 +306,7 @@ def test_phi2_export_module(self):
305306
model,
306307
(),
307308
kwargs=inputs,
308-
dynamic_shapes=dyn_shapes,
309+
dynamic_shapes=use_dyn_not_str(dyn_shapes),
309310
strict=False, # True works but then the it fails during the execution
310311
)
311312
# ep = ep.run_decompositions()
@@ -343,7 +344,7 @@ def test_phi2_export_interpreter(self):
343344
model,
344345
(),
345346
kwargs=inputs,
346-
dynamic_shapes=dyn_shapes,
347+
dynamic_shapes=use_dyn_not_str(dyn_shapes),
347348
strict=False, # True works but then the it fails during the execution
348349
)
349350
# ep = ep.run_decompositions()

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,8 @@ def loop_body_1(z, iv, x, y):
604604
rewritten_expected2 = RewrittenModel2()(x, y)
605605
self.assertEqualArray(expected, rewritten_expected2)
606606

607-
if not has_torch("2.9"):
608-
raise unittest.SkipTest("skipped export, torch must be >= 2.9")
607+
if not has_torch("2.10"):
608+
raise unittest.SkipTest("skipped export, torch must be >= 2.10")
609609

610610
torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
611611
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_vmap(self):
1717
got = patched_vmap(f)(x, y)
1818
self.assertEqualArray(expected, got)
1919

20-
@requires_torch("2.9")
20+
@requires_torch("2.10")
2121
def test_export_vmap(self):
2222
class Model(torch.nn.Module):
2323
def forward(self, x, y):
@@ -206,10 +206,11 @@ def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callabl
206206

207207
class Model(torch.nn.Module):
208208
def forward(self, batch_arange, head_arange, cache_position, kv_arange):
209-
with TransformGetItemToIndex():
210-
causal_mask2 = _vmap_for_bhqkv2(mask_function)(
211-
batch_arange, head_arange, cache_position, kv_arange
212-
)
209+
# with TransformGetItemToIndex():
210+
# This context as ignored in 2.8 and not any more in 2.9.
211+
causal_mask2 = _vmap_for_bhqkv2(mask_function)(
212+
batch_arange, head_arange, cache_position, kv_arange
213+
)
213214
return causal_mask2
214215

215216
inputs = batch_arange, head_arange, cache_position, kv_arange

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_f_validate_model_onnx_dynamo_ir(self):
9696
)
9797

9898
@requires_torch("2.7")
99-
@requires_onnxscript("0.5")
99+
@requires_onnxscript("0.7")
100100
@hide_stdout()
101101
@ignore_warnings(FutureWarning)
102102
def test_g_validate_model_onnx_dynamo_os_ort(self):

0 commit comments

Comments
 (0)