Skip to content

Commit 57e83b6

Browse files
committed
fix
1 parent 32f6dee commit 57e83b6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
from onnx_diagnostic.torch_models.llms import get_phi2
1010
from onnx_diagnostic.helpers import string_type
11+
from onnx_diagnostic.torch_export_patches import torch_export_patches
1112

1213

1314
class TestLlmPhi(ExtTestCase):
@@ -27,7 +28,8 @@ def test_export_phi2_1(self):
2728
self.assertEqual(
2829
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
2930
)
30-
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
31+
with torch_export_patches(patch_transformers=True):
32+
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
3133
assert ep
3234

3335

0 commit comments

Comments
 (0)