Skip to content

Commit 2108a19

Browse files
committed
Improves documentation
1 parent 4a0781b commit 2108a19

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

_doc/examples/plot_dump_intermediate_results.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@
128128
# results to intermediate results in ONNX.
129129
# Let's create the ONNX model.
130130

131-
epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
131+
ep = torch.export.export(model, inputs, dynamic_shapes=ds)
132+
epo = torch.onnx.export(ep, dynamo=True)
132133
epo.optimize()
133134
epo.save("plot_dump_intermediate_results.onnx")
134135

_doc/recipes/plot_export_dim1.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,12 @@ def forward(self, x, y, z):
5454
# Same model, a dynamic dimension = 1 and backed_size_oblivious=True
5555
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
5656

57-
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
58-
ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
59-
print(ep.graph)
57+
try:
58+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
59+
ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
60+
print(ep.graph)
61+
except RuntimeError as e:
62+
print("ERROR", e)
6063

6164
# %%
6265
# It worked.

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,16 @@ def flatten_unflatten_for_dynamic_shapes(
9696
return tuple(subtrees)
9797
if spec.type is list:
9898
return list(subtrees)
99+
if spec.type is None and not subtrees:
100+
return None
99101
if spec.context:
100102
# This is a custom class with attributes.
101103
# It is returned as a list.
102104
return list(subtrees)
103105
raise ValueError(
104106
f"Unable to interpret spec type {spec.type} "
105-
f"(type is {type(spec.type)}, context is {spec.context})."
107+
f"(type is {type(spec.type)}, context is {spec.context}), "
108+
f"spec={spec}, subtrees={subtrees}"
106109
)
107110
# This is a list.
108111
return subtrees

0 commit comments

Comments
 (0)