Skip to content

Commit 1d7c7f7

Browse files
committed
Enable strings in guess_dynamic_shapes
1 parent 750b6dd commit 1d7c7f7

File tree

4 files changed

+124
-21
lines changed

4 files changed

+124
-21
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Enlightening Examples
6767
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
6868
* `Find and fix an export issue due to dynamic shapes
6969
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_locate_issue.html>`_
70-
* `Export with DynamicCache and dynamic shapes
70+
* `Export with DynamicCache and guessed dynamic shapes
7171
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
7272
* `Steal method forward to guess the dynamic shapes (with Tiny-LLM)
7373
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
.. _l-plot-export-with-dynamic-shape:
33
4-
===========================================
5-
Export with DynamicCache and dynamic shapes
6-
===========================================
4+
===================================================
5+
Export with DynamicCache and guessed dynamic shapes
6+
===================================================
77
88
Every LLM implemented in :epkg:`transformers` uses a cache.
99
One of the most used is :class:`transformers.cache_utils.DynamicCache`.
@@ -84,6 +84,8 @@ def forward(self, cache, z):
8484
print(string_type(inputs[1], with_shape=True))
8585

8686
# %%
87+
# .. _l-guess-dynamic-shapes-example:
88+
#
8789
# Guess the dynamic shapes
8890
# ========================
8991
#
@@ -112,6 +114,17 @@ def forward(self, cache, z):
112114
)
113115
print(ep)
114116

117+
# %%
118+
# Use string instead of DYNAMIC
119+
# +++++++++++++++++++++++++++++
120+
#
121+
# ONNX exporter considers strings instead of DYNAMIC or AUTO
122+
# to give names to every dimension.
123+
124+
dss = mi.guess_dynamic_shapes(auto="dim")
125+
pprint.pprint(dss)
126+
127+
115128
# %%
116129
# Do we need to guess?
117130
# ++++++++++++++++++++

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,65 @@ def forward(self, cache, z):
470470
),
471471
)
472472

473+
def test_guess_dynamic_shapes_cache_str(self):
474+
class Model(torch.nn.Module):
475+
def forward(self, cache, z):
476+
return (
477+
z
478+
+ cache.key_cache[0]
479+
+ cache.key_cache[1]
480+
+ cache.value_cache[0]
481+
+ cache.value_cache[1]
482+
)
483+
484+
model = Model()
485+
486+
n_layers = 2
487+
bsize, nheads, slen, dim = 2, 4, 3, 7
488+
cache = make_dynamic_cache(
489+
[
490+
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
491+
for i in range(n_layers)
492+
]
493+
)
494+
z = torch.randn((1, 1, 1, 7))
495+
model(cache, z)
496+
497+
cache2 = make_dynamic_cache(
498+
[
499+
(
500+
torch.randn(bsize, nheads, slen, dim),
501+
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
502+
)
503+
for i in range(n_layers)
504+
]
505+
)
506+
inputs = [
507+
(cache, z),
508+
(cache2, torch.randn((1, 1, 1, 8))),
509+
]
510+
511+
mi = ModelInputs(Model(), inputs)
512+
self.assertIn("DynamicCache", string_type(mi.inputs, with_shape=True))
513+
ds = mi.guess_dynamic_shapes(auto="dim")
514+
print(ds)
515+
self.assertEqual(
516+
ds,
517+
(
518+
(
519+
[
520+
[{}, {}],
521+
[
522+
{0: "dim_0I_1o_0l0", 2: "dim_0I_1o_0l2", 3: "dim_0I_1o_0l3"},
523+
{0: "dim_0I_1o_1l0", 2: "dim_0I_1o_1l2", 3: "dim_0I_1o_1l3"},
524+
],
525+
],
526+
{3: "dim_1I3"},
527+
),
528+
{},
529+
),
530+
)
531+
473532
def test_couple_input_ds_0(self):
474533
T3x4 = torch.rand((3, 4))
475534
T3x1 = torch.rand((3, 1))

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -723,13 +723,14 @@ def module_name_type(self):
723723
return f"type({self.name})={self.true_model_name}.{self.method_name}"
724724

725725
def guess_dynamic_dimensions(
726-
self, *tensors, auto: bool = False
726+
self, *tensors, auto: Union[bool, str] = False
727727
) -> Optional[Dict[int, Any]]:
728728
"""
729729
Infers the dynamic dimension from multiple shapes.
730730
If auto is True, it returns ``torch.export.Dim.AUTO`` for every dimension
731731
which cannot be guessed. Two tensors with the same value for one dimension
732-
can be guessed, but if there is only 1, it cannot.
732+
can be guessed, but if there is only 1, it cannot. ``auto`` can be a string
733+
to produce strings.
733734
"""
734735
if len(tensors) == 1:
735736
if isinstance(tensors[0], (int, float)):
@@ -740,7 +741,7 @@ def guess_dynamic_dimensions(
740741
)
741742
return (
742743
{i: torch.export.Dim.AUTO for i in range(len(tensors[0].shape))} # noqa: C420
743-
if auto
744+
if auto and not isinstance(auto, str)
744745
else {}
745746
)
746747
shapes = [t.shape for t in tensors]
@@ -750,22 +751,26 @@ def guess_dynamic_dimensions(
750751
f"shapes={shapes} for module {self.name!r}, "
751752
f"class={self.true_model_name!r}"
752753
)
753-
dynamic: Any = torch.export.Dim.DYNAMIC # type: ignore
754+
dynamic: Any = (
755+
auto
756+
if isinstance(auto, str)
757+
else (torch.export.Dim.AUTO if auto else torch.export.Dim.DYNAMIC)
758+
)
754759
rk = set_length.pop()
755760
res = {}
756761
for i in range(rk):
757762
set_dim = set(s[i] for s in shapes)
758763
if len(set_dim) > 1:
759-
res[i] = dynamic
764+
res[i] = dynamic if not isinstance(dynamic, str) else f"{dynamic}{i}"
760765
continue
761766
if set_dim == {0}:
762767
# It is unexpected to find a null dimension. Let's replace it by a dynamic one.
763-
res[i] = dynamic
768+
res[i] = dynamic if not isinstance(dynamic, str) else f"{dynamic}{i}"
764769
continue
765770
return res
766771

767772
def guess_dynamic_shape_object(
768-
self, *objs: Any, auto: bool = False, msg: Optional[Callable] = None
773+
self, *objs: Any, auto: Union[bool, str] = False, msg: Optional[Callable] = None
769774
) -> Any:
770775
"""Guesses the dynamic shapes for one argument."""
771776
if len(objs) == 0:
@@ -790,7 +795,11 @@ def guess_dynamic_shape_object(
790795
shapes: Any = []
791796
for i in range(kl.pop()):
792797
shapes.append(
793-
self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
798+
self.guess_dynamic_shape_object(
799+
*[o[i] for o in objs],
800+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}t",
801+
msg=msg,
802+
)
794803
)
795804
return tuple(shapes)
796805

@@ -802,7 +811,11 @@ def guess_dynamic_shape_object(
802811
shapes = []
803812
for i in range(kl.pop()):
804813
shapes.append(
805-
self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
814+
self.guess_dynamic_shape_object(
815+
*[o[i] for o in objs],
816+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}l",
817+
msg=msg,
818+
)
806819
)
807820
return shapes
808821

@@ -814,7 +827,9 @@ def guess_dynamic_shape_object(
814827
shapes = {}
815828
for i in obj:
816829
shapes[i] = self.guess_dynamic_shape_object(
817-
*[o[i] for o in objs], auto=auto, msg=msg
830+
*[o[i] for o in objs],
831+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}d",
832+
msg=msg,
818833
)
819834
return shapes
820835

@@ -834,7 +849,9 @@ def guess_dynamic_shape_object(
834849
for i in range(kc.pop()):
835850
values.append(
836851
self.guess_dynamic_shape_object(
837-
*[ca[i] for ca in col_args], auto=auto, msg=msg
852+
*[ca[i] for ca in col_args],
853+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}o",
854+
msg=msg,
838855
)
839856
)
840857
return values
@@ -852,12 +869,18 @@ def guess_dynamic_shape_object(
852869
key_cache = []
853870
for i in range(kc.pop()):
854871
key_cache.append(
855-
self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs], auto=auto)
872+
self.guess_dynamic_dimensions(
873+
*[o.key_cache[i] for o in objs],
874+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
875+
)
856876
)
857877
value_cache = []
858878
for i in range(vc.pop()):
859879
value_cache.append(
860-
self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs], auto=auto)
880+
self.guess_dynamic_dimensions(
881+
*[o.value_cache[i] for o in objs],
882+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
883+
)
861884
)
862885
return [key_cache, value_cache]
863886

@@ -867,13 +890,17 @@ def guess_dynamic_shape_object(
867890
f"this object needs serialization function to be registered."
868891
)
869892

870-
def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
893+
def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES:
871894
"""
872895
Guesses the dynamic shapes for that module from two executions.
873896
If there is only one execution, then that would be static dimensions.
874897
875898
:param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
876-
dimension if the number of inputs is one
899+
dimension if the number of inputs is one,
900+
if ``auto`` is a string, it uses strings
901+
:return: guessed dynamic shapes
902+
903+
See example :ref:`l-guess-dynamic-shapes-example`.
877904
"""
878905
if len(self.inputs) == 0:
879906
# No inputs, unable to guess.
@@ -900,7 +927,9 @@ def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
900927
objs = [_[0][i] for _ in self.inputs]
901928
args.append(
902929
self.guess_dynamic_shape_object(
903-
*objs, auto=auto, msg=lambda i=i: f" failing input {i}"
930+
*objs,
931+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
932+
msg=lambda i=i: f" failing input {i}",
904933
)
905934
)
906935
names = s2.pop()
@@ -913,7 +942,9 @@ def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
913942

914943
objs = [_[1][name] for _ in self.inputs]
915944
kwargs[name] = self.guess_dynamic_shape_object(
916-
*objs, auto=auto, msg=lambda name=name: f" failing input {name!r}"
945+
*objs,
946+
auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
947+
msg=lambda name=name: f" failing input {name!r}",
917948
)
918949
return tuple(args), kwargs
919950

0 commit comments

Comments
 (0)