Skip to content

Commit ec40294

Browse files
committed
install
1 parent c37e666 commit ec40294

File tree

7 files changed

+77
-41
lines changed

7 files changed

+77
-41
lines changed

_unittests/ut_tasks/test_tasks_image_to_video.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@ def test_image_to_video_oblivious(self):
5454
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
5555
model(**inputs)
5656
model(**data["inputs2"])
57-
with torch.fx.experimental._config.patch(
58-
backed_size_oblivious=True
59-
), torch_export_patches(
60-
patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
57+
with (
58+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
59+
torch_export_patches(
60+
patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
61+
),
6162
):
6263
torch.export.export(
6364
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False

_unittests/ut_tasks/try_tasks.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -917,13 +917,18 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
917917
# use_cache:bool,logits_to_keep:None,return_dict:bool)
918918

919919
print()
920-
with torch_export_patches(
921-
patch_torch=False, patch_sympy=False, patch_transformers=True
922-
), steal_forward(
923-
model,
924-
dump_file=self.get_dump_file("test_imagetext2text_generation_gemma3_4b_it.onnx"),
925-
dump_drop={"attention_mask", "past_key_values", "pixel_values"},
926-
save_as_external_data=False,
920+
with (
921+
torch_export_patches(
922+
patch_torch=False, patch_sympy=False, patch_transformers=True
923+
),
924+
steal_forward(
925+
model,
926+
dump_file=self.get_dump_file(
927+
"test_imagetext2text_generation_gemma3_4b_it.onnx"
928+
),
929+
dump_drop={"attention_mask", "past_key_values", "pixel_values"},
930+
save_as_external_data=False,
931+
),
927932
):
928933
generated_ids = model.generate(
929934
**inputs,

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,9 @@ def forward(self, x, ind1, ind2):
309309
with self.subTest(
310310
name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes
311311
):
312-
with torch_export_patches(), torch.fx.experimental._config.patch(
313-
backed_size_oblivious=True
312+
with (
313+
torch_export_patches(),
314+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
314315
):
315316
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
316317
got = ep.module()(*inputs)

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ def test_export_phi2_1_batch_size_1_oblivious(self):
3333
self.assertEqual(
3434
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
3535
)
36-
with torch.fx.experimental._config.patch(
37-
backed_size_oblivious=True
38-
), torch_export_patches(patch_transformers=True):
36+
with (
37+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
38+
torch_export_patches(patch_transformers=True),
39+
):
3940
ep = torch.export.export(
4041
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
4142
)

onnx_diagnostic/torch_export_patches/eval/__init__.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ def _make_exporter_export(
213213
backed_size_oblivious=backed_size_oblivious,
214214
)
215215
else:
216-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
217-
io.StringIO()
216+
with (
217+
contextlib.redirect_stdout(io.StringIO()),
218+
contextlib.redirect_stderr(io.StringIO()),
218219
):
219220
exported = _wrap_torch_export(
220221
model,
@@ -260,8 +261,9 @@ def _make_exporter_export(
260261
else exported.run_decompositions({})
261262
)
262263
else:
263-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
264-
io.StringIO()
264+
with (
265+
contextlib.redirect_stdout(io.StringIO()),
266+
contextlib.redirect_stderr(io.StringIO()),
265267
):
266268
exported = _wrap_torch_export(
267269
model,
@@ -295,8 +297,9 @@ def _make_exporter_export(
295297
graph = CustomTracer().trace(model)
296298
mod = torch.fx.GraphModule(model, graph)
297299
else:
298-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
299-
io.StringIO()
300+
with (
301+
contextlib.redirect_stdout(io.StringIO()),
302+
contextlib.redirect_stderr(io.StringIO()),
300303
):
301304
graph = CustomTracer().trace(model)
302305
mod = torch.fx.GraphModule(model, graph)
@@ -341,8 +344,9 @@ def _make_exporter_onnx(
341344
return_builder=True,
342345
)
343346
else:
344-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
345-
io.StringIO()
347+
with (
348+
contextlib.redirect_stdout(io.StringIO()),
349+
contextlib.redirect_stderr(io.StringIO()),
346350
):
347351
onx, builder = to_onnx(
348352
model,
@@ -375,8 +379,9 @@ def _make_exporter_onnx(
375379
report=True,
376380
).model_proto
377381
else:
378-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
379-
io.StringIO()
382+
with (
383+
contextlib.redirect_stdout(io.StringIO()),
384+
contextlib.redirect_stderr(io.StringIO()),
380385
):
381386
onx = torch.onnx.export(
382387
model,
@@ -410,8 +415,9 @@ def _make_exporter_onnx(
410415
ep.optimize()
411416
onx = ep.model_proto
412417
else:
413-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
414-
io.StringIO()
418+
with (
419+
contextlib.redirect_stdout(io.StringIO()),
420+
contextlib.redirect_stderr(io.StringIO()),
415421
):
416422
ep = torch.onnx.export(
417423
model,

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -268,19 +268,22 @@ def torch_export_patches(
268268
if rewrite:
269269
from .patch_module import torch_export_rewrite
270270

271-
with torch_export_rewrite(
272-
rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
273-
), torch_export_patches( # type: ignore[var-annotated]
274-
patch_sympy=patch_sympy,
275-
patch_torch=patch_torch,
276-
patch_transformers=patch_transformers,
277-
patch_diffusers=patch_diffusers,
278-
catch_constraints=catch_constraints,
279-
stop_if_static=stop_if_static,
280-
verbose=verbose,
281-
patch=patch,
282-
custom_patches=custom_patches,
283-
) as f:
271+
with (
272+
torch_export_rewrite(
273+
rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
274+
),
275+
torch_export_patches( # type: ignore[var-annotated]
276+
patch_sympy=patch_sympy,
277+
patch_torch=patch_torch,
278+
patch_transformers=patch_transformers,
279+
patch_diffusers=patch_diffusers,
280+
catch_constraints=catch_constraints,
281+
stop_if_static=stop_if_static,
282+
verbose=verbose,
283+
patch=patch,
284+
custom_patches=custom_patches,
285+
) as f,
286+
):
284287
try:
285288
yield f
286289
finally:

pyproject.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
[project]
2+
name = "onnx-diagnostic"
3+
version = "0.7.13"
4+
description = "Tools to help converting pytorch models into ONNX."
5+
readme = "README.rst"
6+
authors = [
7+
{ name = "Xavier Dupré", email = "[email protected]" }
8+
]
9+
license = { text = "MIT" }
10+
requires-python = ">=3.9"
11+
dependencies = []
12+
13+
[project.urls]
14+
Homepage = "https://sdpython.github.io/doc/onnx-diagnostic/dev/"
15+
Repository = "https://github.com/sdpython/onnx-diagnostic/"
16+
17+
[tool.setuptools.dynamic]
18+
dependencies = {file = "requirements.txt"}
19+
120
[tool.setuptools.package-data]
221
onnx_diagnostic = ["tasks/data/*.onnx"]
322

0 commit comments

Comments
 (0)