Skip to content

Commit 8ad4bba

Browse files
committed
mypy
1 parent 82d4cd7 commit 8ad4bba

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@ def test_onnx_export_tiny_llm_official(self):
2828
self.assertEqual(
2929
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
3030
)
31-
ep = torch.onnx.export(
32-
model,
33-
(),
34-
kwargs=inputs,
35-
dynamic_shapes=data["dynamic_shapes"],
36-
dynamo=True,
37-
optimize=True,
38-
)
31+
with torch_export_patches(patch_transformers=True):
32+
ep = torch.onnx.export(
33+
model,
34+
(),
35+
kwargs=inputs,
36+
dynamic_shapes=data["dynamic_shapes"],
37+
dynamo=True,
38+
optimize=True,
39+
)
3940
# There are some discrepancies with torch==2.6
4041
if not has_torch("2.7"):
4142
raise unittest.SkipTest("discrepancies observed with torch<2.7")

onnx_diagnostic/helpers/graph_helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ def text_rendering(self, prefix="") -> str:
290290
*self.nodes,
291291
*[oh.make_node(i, [i], ["END"]) for i in self.output_names],
292292
]
293-
existing = set(self.start_names) - set(self.input_names)
294-
existing |= {"BEGIN"}
295-
existing = sorted(existing)
293+
exist = set(self.start_names) - set(self.input_names)
294+
exist |= {"BEGIN"}
295+
existing = sorted(exist)
296296
order = self.computation_order(nodes, existing)
297297
positions = self.graph_positions(nodes, order, existing)
298298
text_pos = self.text_positions(nodes, positions)

0 commit comments

Comments
 (0)