Skip to content

Commit 1452052

Browse files
committed
json ambiguities
1 parent d2a78c4 commit 1452052

File tree

2 files changed

+111
-4
lines changed

2 files changed

+111
-4
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
JSON returns list when the original dynamic shapes are list or tuple
3+
====================================================================
4+
5+
Dynamic Shapes After JSON
6+
+++++++++++++++++++++++++
7+
"""
8+
9+
import json
10+
import pprint
11+
import torch
12+
from onnx_diagnostic import doc
13+
from onnx_diagnostic.helpers import string_type
14+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
15+
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
16+
17+
bsize, nheads, slen, dim = 2, 1, 30, 96
18+
19+
inputs = dict(
20+
input_mask_position=(
21+
torch.randint(15, size=(2, 3), dtype=torch.int64),
22+
torch.randint(1, size=(2, 33), dtype=torch.int64),
23+
torch.arange(3, dtype=torch.int64),
24+
),
25+
past_key_values=make_dynamic_cache(
26+
[(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
27+
),
28+
)
29+
30+
print(string_type(inputs, with_shape=True))
31+
32+
# %%
33+
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
34+
# produces the corresponding dynamic shapes assuming they are all dynamic.
35+
ds = all_dynamic_shape_from_inputs(inputs)
36+
pprint.pprint(ds)
37+
38+
# %%
39+
# Converted into JSON.
40+
41+
json_str = json.dumps(ds, indent=2, ensure_ascii=False)
42+
print(json_str)
43+
44+
# %%
45+
# Restoration.
46+
ds2 = json.loads(json_str)
47+
pprint.pprint(ds2)
48+
49+
# %%
50+
# tuple are replaced by list.
51+
52+
# The trick
53+
# +++++++++
54+
55+
56+
def flatten_unflatten_like_dynamic_shapes(obj):
57+
if isinstance(obj, torch.Tensor):
58+
return obj
59+
flat, spec = torch.utils._pytree.tree_flatten(obj)
60+
start = 0
61+
end = 0
62+
subtrees = []
63+
for subspec in spec.children_specs:
64+
end += subspec.num_leaves
65+
value = subspec.unflatten(flat[start:end])
66+
value = flatten_unflatten_like_dynamic_shapes(value)
67+
subtrees.append(value)
68+
start = end
69+
if spec.type is dict or spec.context:
70+
return dict(zip(spec.context, subtrees))
71+
if spec.type is tuple:
72+
return tuple(subtrees)
73+
return subtrees
74+
75+
76+
def _align(inputs, ds):
77+
if isinstance(inputs, torch.Tensor):
78+
return ds
79+
if isinstance(inputs, tuple):
80+
return tuple(_align(o, d) for o, d in zip(inputs, ds))
81+
if isinstance(inputs, list):
82+
return [_align(o, d) for o, d in zip(inputs, ds)]
83+
if isinstance(inputs, dict):
84+
return {k: _align(inputs[k], d) for k, d in ds.items()}
85+
raise TypeError(f"Unexpected types inputs is {type(inputs)}, ds is {type(ds)}")
86+
87+
88+
def fix_dynamic_shapes(inputs, dynamic_shapes):
89+
flat_unflat_inputs = flatten_unflatten_like_dynamic_shapes(inputs)
90+
return _align(flat_unflat_inputs, dynamic_shapes)
91+
92+
93+
fixed_ds = fix_dynamic_shapes(inputs, ds2)
94+
pprint.pprint(fixed_ds)
95+
96+
# %%
97+
# The code changed tuple into list as expected.
98+
assert isinstance(ds2["input_mask_position"], list)
99+
assert isinstance(fixed_ds["input_mask_position"], tuple)
100+
101+
102+
# %%
103+
104+
doc.plot_legend("dynamic shapes\nto json\nfrom json", "torch.export.export", "green")

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def flatten_unflatten_for_dynamic_shapes(
1919
:func:`torch.export.export` only considers the values,
2020
the context gives the dictionary keys but it is not expressed
2121
in the dynamic shapes, these specifications seems to be different
22-
for the strict and non strict mode.
22+
for the strict and non strict mode. It also preserves tuple.
2323
:param change_function: to modifies the tensor in the structure itself,
2424
like replace them by a shape
2525
:return: the serialized object
@@ -38,9 +38,12 @@ def flatten_unflatten_for_dynamic_shapes(
3838
)
3939
subtrees.append(value)
4040
start = end
41-
if use_dict and (spec.type is dict or spec.context):
42-
# This a dictionary.
43-
return dict(zip(spec.context, subtrees))
41+
if use_dict:
42+
if spec.type is dict or spec.context:
43+
# This a dictionary.
44+
return dict(zip(spec.context, subtrees))
45+
if spec.type is tuple:
46+
return tuple(subtrees)
4447
# This is a list.
4548
return subtrees
4649

0 commit comments

Comments
 (0)