Skip to content

Commit 6b2b993

Browse files
committed
fix doc
1 parent 1452052 commit 6b2b993

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def linkcode_resolve(domain, info):
233233
"onnx-script": "https://github.com/microsoft/onnxscript",
234234
"onnxscript": "https://github.com/microsoft/onnxscript",
235235
"onnxscript Tutorial": "https://microsoft.github.io/onnxscript/tutorial/index.html",
236+
"optree": "https://github.com/metaopt/optree",
236237
"Pattern-based Rewrite Using Rules With onnxscript": "https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html",
237238
"opsets": "https://onnx.ai/onnx/intro/concepts.html#what-is-an-opset-version",
238239
"pyinstrument": "https://pyinstrument.readthedocs.io/en/latest/",

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,17 @@
22
JSON returns list when the original dynamic shapes are list or tuple
33
====================================================================
44
5+
Dynamic shapes given to :func:`torch.export.export` must follow the
6+
same semantic. What if we confuse tuple and list when defining the dynamic shapes,
7+
how to restore the expected type assuming we know the inputs?
8+
Not often useful but maybe we will learn more about
9+
:epkg:`optree`.
10+
511
Dynamic Shapes After JSON
612
+++++++++++++++++++++++++
13+
14+
JSON format does not make the difference between a list and a tuple.
15+
So after serializing to json and restoring, both of them become lists.
716
"""
817

918
import json
@@ -49,8 +58,8 @@
4958
# %%
5059
# tuple are replaced by list.
5160

52-
# The trick
53-
# +++++++++
61+
# The trick to restore tuple when expected
62+
# ++++++++++++++++++++++++++++++++++++++++
5463

5564

5665
def flatten_unflatten_like_dynamic_shapes(obj):

_unittests/ut_export/test_shape_helper.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,24 @@ class TestShapeHelper(ExtTestCase):
1010
@requires_torch("2.7.99")
1111
def test_all_dynamic_shape_from_inputs(self):
1212
ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
13+
self.assertEqual(({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}), ds)
14+
ds = all_dynamic_shape_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
1315
self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
1416
ds = all_dynamic_shape_from_inputs(
1517
(torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
1618
)
1719
self.assertEqual(
18-
[
20+
(
1921
{0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
2022
{0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
21-
],
23+
),
2224
ds,
2325
)
2426

2527
@requires_transformers("4.52")
2628
@requires_torch("2.7.99")
2729
def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
2830
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
29-
print(self.string_type(data["inputs"], with_shape=True))
3031
ds = all_dynamic_shape_from_inputs(data["inputs"])
3132
self.assertEqual(
3233
{

0 commit comments

Comments
 (0)