Commit d5ea218: improves dynamic shapes handling
1 parent fb58b61

5 files changed, +349 −72 lines changed

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.export import CoupleInputsDynamicShapes
+from onnx_diagnostic.torch_export_patches.patch_inputs import (
+    convert_dynamic_axes_into_dynamic_shapes,
+)
+
+
+class TestCacheHelpers(ExtTestCase):
+    def test_string_type(self):
+        DYN = torch.export.Dim.DYNAMIC
+        self.assertEqual("DYNAMIC", string_type(DYN, verbose=0))
+        AUTO = torch.export.Dim.AUTO
+        self.assertEqual("AUTO", string_type(AUTO, verbose=0))
+        self.assertEqual("#1[DYNAMIC]", string_type([DYN]))
+
+        batch = torch.export.Dim("batch")
+        dynamic_shapes = dict(
+            input_ids={0: batch, 1: "seq"},
+            attention_mask={0: batch, 1: "seq"},
+            position_ids={0: batch, 1: "seq"},
+            past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
+        )
+        self.assertEqual(
+            "dict(input_ids:{0:Dim(batch),1:DYN(seq)},"
+            "attention_mask:{0:Dim(batch),1:DYN(seq)},"
+            "position_ids:{0:Dim(batch),1:DYN(seq)},"
+            "past_key_values:#2[#1[{0:Dim(batch),2:DYN(seq)}],"
+            "#1[{0:Dim(batch),2:DYN(seq)}]])",
+            string_type(dynamic_shapes),
+        )
+
+    def test_replace_by(self):
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_dynamic_cache(
+            [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
+        )
+        kwargs = dict(
+            input_ids=torch.zeros(2, 3),
+            attention_mask=torch.zeros(2, 3),
+            position_ids=torch.zeros(2, 3),
+            past_key_values=past_key_values,
+        )
+        batch = torch.export.Dim("batch")
+        dynamic_shapes = dict(
+            input_ids={0: batch, 1: "seq"},
+            attention_mask={0: batch, 1: "seq"},
+            position_ids={0: batch, 1: "seq"},
+            past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
+        )
+
+        DYN = torch.export.Dim.DYNAMIC
+        nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
+            None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes
+        )
+        self.assertEqual(dynamic_shapes, nds)
+
+        cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes)
+        res = cpl.replace_string_by()
+        dsc = res["past_key_values"]
+        self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
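
The test above pins down CoupleInputsDynamicShapes.replace_string_by: string axis
names such as "seq" are rewritten into torch.export.Dim.DYNAMIC, including inside
a DynamicCache. For context, a minimal sketch of how the rewritten spec might then
be handed to the exporter; the toy Model below is hypothetical and not part of
this commit.

    import torch
    from onnx_diagnostic.export import CoupleInputsDynamicShapes

    class Model(torch.nn.Module):  # hypothetical toy model
        def forward(self, input_ids):
            return input_ids * 2

    kwargs = dict(input_ids=torch.zeros((2, 3), dtype=torch.int64))
    # "seq" names a dynamic axis as a string; replace_string_by() substitutes
    # torch.export.Dim.DYNAMIC, its default replacement value.
    dynamic_shapes = dict(input_ids={0: torch.export.Dim("batch"), 1: "seq"})
    ds = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes).replace_string_by()

    ep = torch.export.export(Model(), tuple(), kwargs=kwargs, dynamic_shapes=ds)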

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 51 additions & 13 deletions
@@ -92,7 +92,8 @@ def replace_string_by(self, value: Any = None):
         return self._generic_walker(
             lambda inputs, ds, value=value: self._replace_string_dim_tensor(
                 inputs, ds, value=value
-            )
+            ),
+            flatten_unflatten=True,
         )
 
     @classmethod
@@ -135,7 +136,8 @@ def replace_by_string(self):
         return self._generic_walker(
             lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
                 inputs, ds, unique=unique
-            )
+            ),
+            flatten_unflatten=True,
         )
 
     @classmethod
@@ -203,7 +205,7 @@ def invalid_dimensions_for_export(self):
         ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
         print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
         """
-        return self._generic_walker(self._valid_shapes_tensor)
+        return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)
 
     @classmethod
     def _valid_shapes_tensor(cls, inputs, ds):
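
The docstring shown in context gives the intended call pattern for
invalid_dimensions_for_export. A self-contained variant with illustrative shapes
and axis names (a sketch, not taken from the commit):

    import torch
    from onnx_diagnostic.export import CoupleInputsDynamicShapes

    ds_batch = {0: "batch"}
    ds_batch_seq = {0: "batch", 1: "seq"}
    kwargs = {"A": torch.randn(2, 3), "B": (torch.randn(2, 3), torch.randn(2, 3))}
    ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
    # Returns None when every declared dynamic dimension is exportable,
    # otherwise a structure of the same shape pointing at the offending axes.
    print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())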
@@ -221,7 +223,9 @@ def _valid_shapes_tensor(cls, inputs, ds):
                 issues[i] = f"d=[{d}]"
         return issues if issues else None
 
-    def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
+    def _generic_walker(
+        self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
+    ):
         """
         Generic deserializator walking through inputs and dynamic_shapes all along.
         The function returns a result with the same structure as the dynamic shapes.
@@ -231,15 +235,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
                 f"Type mismatch, args={string_type(self.args)} and "
                 f"dynamic_shapes={self.dynamic_shapes} should have the same type."
             )
-            res = self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
+            res = self._generic_walker_step(
+                processor,
+                self.kwargs,
+                self.dynamic_shapes,
+                flatten_unflatten=flatten_unflatten,
+            )
             return (tuple(), res) if args_kwargs else res
 
         if not self.kwargs:
             assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
                 f"Type mismatch, args={string_type(self.args)} and "
                 f"dynamic_shapes={self.dynamic_shapes} should have the same type."
             )
-            res = self._generic_walker_step(processor, self.args, self.dynamic_shapes)
+            res = self._generic_walker_step(
+                processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
+            )
             return (res, {}) if args_kwargs else res
 
         assert isinstance(self.dynamic_shapes, dict), (
@@ -250,12 +261,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
             self.dynamic_shapes
         ):
             # No dynamic shapes for the positional arguments.
-            return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
+            return self._generic_walker_step(
+                processor,
+                self.kwargs,
+                self.dynamic_shapes,
+                flatten_unflatten=flatten_unflatten,
+            )
 
         if isinstance(self.args_names, list):
             if not set(self.args_names) & set(self.dynamic_shapes):
                 # No dynamic shapes for the positional arguments.
-                return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
+                return self._generic_walker_step(
+                    processor,
+                    self.kwargs,
+                    self.dynamic_shapes,
+                    flatten_unflatten=flatten_unflatten,
+                )
 
         assert self.args_names, (
             "args and kwargs are filled, then args_names must be specified in "
@@ -268,7 +289,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
         )
         kwargs = dict(zip(self.args_names, self.args))
         kwargs.update(self.kwargs)
-        res = self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
+        res = self._generic_walker_step(
+            processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
+        )
         if args_kwargs:
             pgs = [None for _ in range(len(self.args))]
             kws = {}
@@ -286,7 +309,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
         )
 
     @classmethod
-    def _generic_walker_step(cls, processor: Callable, inputs, ds):
+    def _generic_walker_step(
+        cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
+    ):
         if isinstance(inputs, torch.Tensor):
             return processor(inputs, ds)
         if isinstance(inputs, (int, float, str)):
@@ -303,7 +328,11 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
         if isinstance(inputs, (tuple, list)):
             value = []
             for i, d in zip(inputs, ds):
-                value.append(cls._generic_walker_step(processor, i, d))
+                value.append(
+                    cls._generic_walker_step(
+                        processor, i, d, flatten_unflatten=flatten_unflatten
+                    )
+                )
             return (
                 (value if isinstance(ds, list) else tuple(value))
                 if any(v is not None for v in value)
@@ -314,7 +343,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
             ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
             dvalue = {}
             for k, v in inputs.items():
-                t = cls._generic_walker_step(
-                    processor, v, ds[k]
-                )
+                t = cls._generic_walker_step(
+                    processor, v, ds[k], flatten_unflatten=flatten_unflatten
+                )
                 if t is not None:
                     dvalue[k] = t
             return dvalue if dvalue else None
@@ -325,11 +356,18 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
             f"torch.utils._pytree.register_pytree_node, it is not possible to "
             f"map this class with the given dynamic shapes."
         )
+        if flatten_unflatten:
+            flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
+            return cls._generic_walker_step(
+                processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
+            )
         flat, _spec = torch.utils._pytree.tree_flatten(inputs)
         if all(isinstance(t, torch.Tensor) for t in flat):
             # We need to flatten dynamic shapes as well
             ds = flatten_dynamic_shapes(ds)
-        return cls._generic_walker_step(processor, flat, ds)
+        return cls._generic_walker_step(
+            processor, flat, ds, flatten_unflatten=flatten_unflatten
+        )
 
     class ChangeDimensionProcessor:
         def __init__(self, desired_values):
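
The heart of the change is the new flatten_unflatten branch in _generic_walker_step:
a custom class such as transformers' DynamicCache cannot be zipped against a
dynamic-shapes spec like [[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]] directly,
so the walker first converts it into the equivalent nested lists of tensors and
recurses. A rough sketch of that conversion, assuming
flatten_unflatten_for_dynamic_shapes lives in onnx_diagnostic.helpers.cache_helper
alongside make_dynamic_cache, as the test file's imports suggest:

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        flatten_unflatten_for_dynamic_shapes,
        make_dynamic_cache,
    )

    # A DynamicCache holding one layer of (key, value) tensors.
    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])

    # The cache becomes plain nested lists of tensors, mirroring the structure
    # of its dynamic-shapes spec [[key_shapes], [value_shapes]], so the generic
    # walker can traverse inputs and shapes in lockstep.
    as_lists = flatten_unflatten_for_dynamic_shapes(cache)
    print(type(as_lists), len(as_lists))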
