@@ -2,24 +2,23 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
-from ..helpers import string_type, flatten_object
+from ..helpers import string_type
 from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
-from ..helpers.fake_tensor_helper import make_fake

 DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]


-def flatten_dynamic_shapes(ds: Any) -> Any:
+def _flatten_dynamic_shapes(ds: Any) -> Any:
     """Flattens the dynamic shapes."""
     if isinstance(ds, list):
-        return _flat_list([flatten_dynamic_shapes(t) for t in ds])
+        return _flat_list([_flatten_dynamic_shapes(t) for t in ds])
     if isinstance(ds, tuple):
-        return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds]))
+        return tuple(_flat_list([_flatten_dynamic_shapes(t) for t in ds]))
     if isinstance(ds, dict):
         if all(isinstance(i, int) for i in ds):
             # That's a dynamic shape
             return ds
-        return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()])
+        return _flat_list([_flatten_dynamic_shapes(t) for t in ds.values()])
     raise AssertionError(f"Not implemented for {type(ds)}: {ds}")

@@ -33,51 +32,6 @@ def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
     return res


-def make_fake_with_dynamic_dimensions(
-    x: Optional[Any],
-    dynamic_shapes: Any,
-    fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
-) -> Optional[Tuple["FakeTensor", "FaleTensorMode"]]:  # noqa: F821
-    """
-    Replaces all tensors by fake tensor respecting the same
-    constraints as the following dynamic shapes.
-    This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
-
-    .. runpython::
-        :showcode:
-
-        from onnx_diagnostic.export.dynamic_shapes import make_fake_with_dynamic_dimensions
-
-        inputs, _ = make_fake_with_dynamic_dimensions(
-            dict(
-                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
-                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
-                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
-                past_key_values=make_dynamic_cache(
-                    [
-                        (
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                        ),
-                        (
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                        ),
-                    ]
-                ),
-            )
-        )
-        print(inputs)
-    """
-    fake_inputs = make_fake(x, fake_mode=fake_mode)
-    flat_inputs = flatten_object(fake_inputs, drop_keys=True)
-    flat_ds = flatten_dynamic_shapes(dynamic_shapes)
-    assert len(flat_inputs) == len(flat_ds), (
-        f"Mismatch between the number of input tensor {len(flat_inputs)} "
-        f"and the number of dynamic_shapes {len(flat_ds)}"
-    )
-
-
 class CoupleInputsDynamicShapes:
     """
     Pair inputs / dynamic shapes.
@@ -426,7 +380,7 @@ def _generic_walker_step(
         flat, spec = torch.utils._pytree.tree_flatten(inputs)
         if all(isinstance(t, torch.Tensor) for t in flat):
             # We need to flatten dynamic shapes as well
-            ds = flatten_dynamic_shapes(ds)
+            ds = _flatten_dynamic_shapes(ds)
             res = cls._generic_walker_step(
                 processor, flat, ds, flatten_unflatten=flatten_unflatten
             )
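
For reference, below is a minimal, self-contained sketch of what the renamed `_flatten_dynamic_shapes` helper computes. The body of `_flat_list` is elided in the diff above, so the version here is an assumption consistent with its signature (it is assumed to concatenate nested per-tensor shape dictionaries into one flat list); the `ds` sample is purely illustrative, not taken from the library.

```python
from typing import Any, Dict, List


def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
    # Assumption: concatenate already-flattened sublists into one flat list.
    res: List[Dict[int, str]] = []
    for el in li:
        if isinstance(el, (list, tuple)):
            res.extend(el)
        else:
            res.append(el)
    return res


def _flatten_dynamic_shapes(ds: Any) -> Any:
    """Flattens the dynamic shapes (copied from the hunk above)."""
    if isinstance(ds, list):
        return _flat_list([_flatten_dynamic_shapes(t) for t in ds])
    if isinstance(ds, tuple):
        return tuple(_flat_list([_flatten_dynamic_shapes(t) for t in ds]))
    if isinstance(ds, dict):
        if all(isinstance(i, int) for i in ds):
            # A leaf: a per-tensor shape spec such as {0: "batch"}.
            return ds
        return _flat_list([_flatten_dynamic_shapes(t) for t in ds.values()])
    raise AssertionError(f"Not implemented for {type(ds)}: {ds}")


# Hypothetical nested dynamic shapes for a model with a key/value cache.
ds = {
    "input_ids": {0: "batch", 1: "seq"},
    "past_key_values": [
        [{0: "batch", 2: "cache"}, {0: "batch", 2: "cache"}],
    ],
}
print(_flatten_dynamic_shapes(ds))
# [{0: 'batch', 1: 'seq'}, {0: 'batch', 2: 'cache'}, {0: 'batch', 2: 'cache'}]
```

This mirrors what the last hunk relies on: `torch.utils._pytree.tree_flatten` reduces the nested inputs to a flat list of leaf tensors plus a structure spec, so flattening the dynamic shapes the same way lets the walker match shape specs to tensors pairwise.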