Skip to content

Commit dc216f4

Browse files
committed
fix test case
1 parent 931fe47 commit dc216f4

File tree

4 files changed

+125
-44
lines changed

4 files changed

+125
-44
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ temp_dump_models/*
4343
dump_dort_bench/*
4444
test_zoo_*/
4545
temp_llama_model_*/*
46+
test_example_transform_method_*.py
4647
test_*.txt
4748
_cache/*
4849
.coverage

_doc/status/exported_program_dynamic.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ to the original model.
2929
for name, cls_model in sorted_cases:
3030
print(f"* :ref:`{name} <led-model-case-export-{name}>`")
3131
print()
32+
print()
3233

3334
obs = []
3435
for name, cls_model in sorted(cases.items()):
@@ -41,14 +42,15 @@ to the original model.
4142
print("forward")
4243
print("+++++++")
4344
print()
44-
print("::")
45+
print(".. code-block:: python")
4546
print()
4647
src = inspect.getsource(cls_model.forward)
4748
if src:
4849
print(textwrap.indent(textwrap.dedent(src), " "))
4950
else:
5051
print(" # code is missing")
5152
print()
53+
print()
5254
for exporter in (
5355
"export-strict",
5456
"export-nostrict",
@@ -67,23 +69,26 @@ to the original model.
6769
if "dynamic_shapes" in res:
6870
print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
6971
print()
72+
print()
7073
if "exported" in res:
71-
print("::")
74+
print(".. code-block:: text")
7275
print()
7376
print(textwrap.indent(str(res["exported"].graph), " "))
7477
print()
78+
print()
7579
obs.append(dict(case=case_ref, error="", exporter=expo))
7680
else:
7781
print("**FAILED**")
7882
print()
79-
print("::")
83+
print(".. code-block:: text")
8084
print()
8185
err = str(res["error"])
8286
if err:
8387
print(textwrap.indent(err, " "))
8488
else:
8589
print(" # no error found for the failure")
8690
print()
91+
print()
8792
obs.append(dict(case=case_ref, error="FAIL", exporter=expo))
8893

8994
print()

onnx_diagnostic/torch_export_patches/eval/__init__.py

Lines changed: 106 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,17 @@ def _make_exporter_export(
185185

186186
if exporter == "export-strict":
187187
try:
188-
exported = torch.export.export(
189-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
190-
)
188+
if verbose >= 2:
189+
exported = torch.export.export(
190+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
191+
)
192+
else:
193+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
194+
io.StringIO()
195+
):
196+
exported = torch.export.export(
197+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
198+
)
191199
except Exception as e:
192200
if not quiet:
193201
raise
@@ -198,17 +206,33 @@ def _make_exporter_export(
198206
return exported.module()
199207
if exporter in ("export-strict-dec", "export-strict-decall"):
200208
try:
201-
exported = torch.export.export(
202-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
203-
)
204-
if verbose >= 9:
205-
print("-- graph before decomposition")
206-
print(exported.graph)
207-
exported = (
208-
exported.run_decompositions()
209-
if "decall" in exporter
210-
else exported.run_decompositions({})
211-
)
209+
if verbose >= 2:
210+
exported = torch.export.export(
211+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
212+
)
213+
if verbose >= 9:
214+
print("-- graph before decomposition")
215+
print(exported.graph)
216+
exported = (
217+
exported.run_decompositions()
218+
if "decall" in exporter
219+
else exported.run_decompositions({})
220+
)
221+
else:
222+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
223+
io.StringIO()
224+
):
225+
exported = torch.export.export(
226+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
227+
)
228+
if verbose >= 9:
229+
print("-- graph before decomposition")
230+
print(exported.graph)
231+
exported = (
232+
exported.run_decompositions()
233+
if "decall" in exporter
234+
else exported.run_decompositions({})
235+
)
212236
except Exception as e:
213237
if not quiet:
214238
raise
@@ -219,9 +243,17 @@ def _make_exporter_export(
219243
return exported.module()
220244
if exporter == "export-nostrict":
221245
try:
222-
exported = torch.export.export(
223-
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
224-
)
246+
if verbose >= 2:
247+
exported = torch.export.export(
248+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
249+
)
250+
else:
251+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
252+
io.StringIO()
253+
):
254+
exported = torch.export.export(
255+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
256+
)
225257
except Exception as e:
226258
if not quiet:
227259
raise
@@ -232,17 +264,33 @@ def _make_exporter_export(
232264
return exported.module()
233265
if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
234266
try:
235-
exported = torch.export.export(
236-
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
237-
)
238-
if verbose >= 9:
239-
print("-- graph before decomposition")
240-
print(exported.graph)
241-
exported = (
242-
exported.run_decompositions()
243-
if "decall" in exporter
244-
else exported.run_decompositions({})
245-
)
267+
if verbose >= 2:
268+
exported = torch.export.export(
269+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
270+
)
271+
if verbose >= 9:
272+
print("-- graph before decomposition")
273+
print(exported.graph)
274+
exported = (
275+
exported.run_decompositions()
276+
if "decall" in exporter
277+
else exported.run_decompositions({})
278+
)
279+
else:
280+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
281+
io.StringIO()
282+
):
283+
exported = torch.export.export(
284+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
285+
)
286+
if verbose >= 9:
287+
print("-- graph before decomposition")
288+
print(exported.graph)
289+
exported = (
290+
exported.run_decompositions()
291+
if "decall" in exporter
292+
else exported.run_decompositions({})
293+
)
246294
except Exception as e:
247295
if not quiet:
248296
raise
@@ -255,8 +303,15 @@ def _make_exporter_export(
255303
from experimental_experiment.torch_interpreter.tracing import CustomTracer
256304

257305
try:
258-
graph = CustomTracer().trace(model)
259-
mod = torch.fx.GraphModule(model, graph)
306+
if verbose >= 2:
307+
graph = CustomTracer().trace(model)
308+
mod = torch.fx.GraphModule(model, graph)
309+
else:
310+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
311+
io.StringIO()
312+
):
313+
graph = CustomTracer().trace(model)
314+
mod = torch.fx.GraphModule(model, graph)
260315
except Exception as e:
261316
if not quiet:
262317
raise
@@ -289,13 +344,25 @@ def _make_exporter_onnx(
289344
if "-dec" in exporter:
290345
opts["decomposition_table"] = "all" if "-decall" in exporter else "default"
291346
try:
292-
onx, builder = to_onnx(
293-
model,
294-
inputs,
295-
dynamic_shapes=dynamic_shapes,
296-
export_options=ExportOptions(**opts),
297-
return_builder=True,
298-
)
347+
if verbose >= 2:
348+
onx, builder = to_onnx(
349+
model,
350+
inputs,
351+
dynamic_shapes=dynamic_shapes,
352+
export_options=ExportOptions(**opts),
353+
return_builder=True,
354+
)
355+
else:
356+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
357+
io.StringIO()
358+
):
359+
onx, builder = to_onnx(
360+
model,
361+
inputs,
362+
dynamic_shapes=dynamic_shapes,
363+
export_options=ExportOptions(**opts),
364+
return_builder=True,
365+
)
299366
except Exception as e:
300367
if not quiet:
301368
raise RuntimeError(
@@ -306,6 +373,7 @@ def _make_exporter_onnx(
306373
) from e
307374
return dict(error=str(e), success=0, error_step="export")
308375
return onx, builder
376+
309377
if exporter == "dynamo":
310378
import torch
311379

@@ -338,6 +406,7 @@ def _make_exporter_onnx(
338406
) from e
339407
return dict(error=str(e), success=0, error_step="export")
340408
return onx, None
409+
341410
if exporter == "dynamo-ir":
342411
import torch
343412

onnx_diagnostic/torch_export_patches/eval/model_cases.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def forward(self, x):
143143
return x
144144

145145
_inputs = [(torch.randn((2, 3, 3)),), (torch.randn((3, 3, 3)),)]
146-
_dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}
146+
_dynamic = {"x": {0: DIM("batch")}}
147147

148148

149149
class AtenInterpolate(torch.nn.Module):
@@ -353,7 +353,9 @@ def else_branch(input_ids, image_features, vocab_size):
353353

354354

355355
class ControlFlowCondIdentity_153832(torch.nn.Module):
356-
"""`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
356+
"""
357+
`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
358+
"""
357359

358360
def forward(self, x, y):
359361

@@ -501,7 +503,9 @@ def forward(self, x, y):
501503

502504

503505
class ControlFlowScanInplace_153705(torch.nn.Module):
504-
"""`#153705 <https://github.com/pytorch/pytorch/issues/153705>`_"""
506+
"""
507+
`#153705 <https://github.com/pytorch/pytorch/issues/153705>`_
508+
"""
505509

506510
def forward(self, x, y):
507511
def loop_body_1(z, iv, x, y):
@@ -524,7 +528,9 @@ def loop_body_1(z, iv, x, y):
524528

525529

526530
class ControlFlowScanDecomposition_151564(torch.nn.Module):
527-
"""`#151564 <https://github.com/pytorch/pytorch/issues/151564>`_"""
531+
"""
532+
`#151564 <https://github.com/pytorch/pytorch/issues/151564>`_
533+
"""
528534

529535
@classmethod
530536
def dummy_loop(cls, padded: torch.Tensor, pos: torch.Tensor):

0 commit comments

Comments
 (0)