diff --git a/_doc/cmds/validate.rst b/_doc/cmds/validate.rst index e09ff699..e27df801 100644 --- a/_doc/cmds/validate.rst +++ b/_doc/cmds/validate.rst @@ -100,6 +100,7 @@ Let's export with ONNX this time and checks for discrepancies. python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir .. runpython:: + :process: from onnx_diagnostic._command_lines_parser import main @@ -117,6 +118,7 @@ of function :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`. python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir --ortfusiontype ALL .. runpython:: + :process: from onnx_diagnostic._command_lines_parser import main diff --git a/_doc/examples/plot_dump_intermediate_results.py b/_doc/examples/plot_dump_intermediate_results.py index df61dd05..9fe760a3 100644 --- a/_doc/examples/plot_dump_intermediate_results.py +++ b/_doc/examples/plot_dump_intermediate_results.py @@ -128,7 +128,8 @@ # results to intermediate results in ONNX. # Let's create the ONNX model. -epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True) +ep = torch.export.export(model, inputs, dynamic_shapes=ds) +epo = torch.onnx.export(ep, dynamo=True) epo.optimize() epo.save("plot_dump_intermediate_results.onnx") diff --git a/_doc/recipes/plot_export_dim1.py b/_doc/recipes/plot_export_dim1.py index 20a85abb..3b9c2f1a 100644 --- a/_doc/recipes/plot_export_dim1.py +++ b/_doc/recipes/plot_export_dim1.py @@ -54,9 +54,12 @@ def forward(self, x, y, z): # Same model, a dynamic dimension = 1 and backed_size_oblivious=True # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -with torch.fx.experimental._config.patch(backed_size_oblivious=True): - ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds)) - print(ep.graph) +try: + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds)) + print(ep.graph) +except RuntimeError as e: + print("ERROR", e) # %% # It worked. diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 3d37425b..9e993cc5 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -96,13 +96,16 @@ def flatten_unflatten_for_dynamic_shapes( return tuple(subtrees) if spec.type is list: return list(subtrees) + if spec.type is None and not subtrees: + return None if spec.context: # This is a custom class with attributes. # It is returned as a list. return list(subtrees) raise ValueError( f"Unable to interpret spec type {spec.type} " - f"(type is {type(spec.type)}, context is {spec.context})." + f"(type is {type(spec.type)}, context is {spec.context}), " + f"spec={spec}, subtrees={subtrees}" ) # This is a list. return subtrees