Skip to content

Commit 5946a0e

Browse files
committed
fix test
1 parent 930b984 commit 5946a0e

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def test_bypass_onnx_export_tiny_llm_official_nopositionids(self):
9797
@ignore_warnings((UserWarning, DeprecationWarning, FutureWarning))
9898
@hide_stdout()
9999
def test_bypass_onnx_export_tiny_llm_official_full(self):
100+
try:
101+
from experimental_experiment.torch_interpreter import to_onnx
102+
except ImportError:
103+
to_onnx = None
104+
100105
data = get_tiny_llm()
101106
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
102107
self.assertEqual(
@@ -106,22 +111,25 @@ def test_bypass_onnx_export_tiny_llm_official_full(self):
106111
patch_transformers=True, verbose=1, stop_if_static=1
107112
) as modificator:
108113
new_inputs = modificator(copy.deepcopy(inputs))
109-
ep = torch.onnx.export(
110-
model,
111-
(),
112-
kwargs=new_inputs,
113-
dynamic_shapes=ds,
114-
dynamo=True,
115-
optimize=True,
116-
report=True,
117-
verify=False,
118-
)
114+
if to_onnx:
115+
proto = to_onnx(model, (), kwargs=new_inputs, dynamic_shapes=ds)
116+
else:
117+
proto = torch.onnx.export(
118+
model,
119+
(),
120+
kwargs=new_inputs,
121+
dynamic_shapes=ds,
122+
dynamo=True,
123+
optimize=True,
124+
report=True,
125+
verify=False,
126+
).model_proto
119127
# There are some discrepancies with torch==2.6
120128
if not has_torch("2.7"):
121129
raise unittest.SkipTest("discrepancies observed with torch<2.7")
122130
self.assert_onnx_disc(
123131
inspect.currentframe().f_code.co_name,
124-
ep.model_proto,
132+
proto,
125133
model,
126134
inputs,
127135
verbose=1,

onnx_diagnostic/helpers/log_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def enumerate_csv_files(
7474
print(
7575
f"[enumerate_csv_files] data[{itn}][{ii}] is a csv file: {name!r}]"
7676
)
77-
with zf.open(name) as f:
78-
line = f.readline()
77+
with zf.open(name) as zzf:
78+
line = zzf.readline()
7979
yield (
8080
os.path.split(name)[-1],
8181
"%04d-%02d-%02d %02d:%02d:%02d" % info.date_time,

0 commit comments

Comments
 (0)