Skip to content

Commit 58a9c14

Browse files
committed
fix bypass
1 parent d87d9c6 commit 58a9c14

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

_unittests/ut_torch_models/test_llms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ def test_export_tiny_llm_2_bypassed(self):
3030
model, inputs = data["model"], data["inputs"]
3131
self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
3232
with bypass_export_some_errors(
33-
patch_transformers=True, replace_dynamic_cache=True, verbose=10
34-
):
33+
patch_transformers=True, replace_dynamic_cache=True
34+
) as modificator:
35+
inputs = modificator(inputs)
3536
ep = torch.export.export(
3637
model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
3738
)

_unittests/ut_torch_models/test_llms_onnx.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ def test_onnx_export_tiny_llm(self):
1818
data = get_tiny_llm()
1919
model, inputs = data["model"], data["inputs"]
2020
self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
21-
with bypass_export_some_errors(patch_transformers=True, replace_dynamic_cache=True):
21+
with bypass_export_some_errors(
22+
patch_transformers=True, replace_dynamic_cache=True
23+
) as modificator:
24+
inputs = modificator(inputs)
2225
ep = torch.onnx.export(
2326
model,
2427
(),
@@ -38,7 +41,10 @@ def test_onnx_export_tiny_llm_cdbg(self):
3841
data = get_tiny_llm()
3942
model, inputs = data["model"], data["inputs"]
4043
self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
41-
with bypass_export_some_errors(patch_transformers=True, replace_dynamic_cache=True):
44+
with bypass_export_some_errors(
45+
patch_transformers=True, replace_dynamic_cache=True
46+
) as modificator:
47+
inputs = modificator(inputs)
4248
onx = to_onnx(model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"])
4349
self.assert_onnx_disc(
4450
inspect.currentframe().f_code.co_name, onx, model, inputs, verbose=1

0 commit comments

Comments
 (0)