@@ -9,6 +9,8 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
99 All dimensions are considered as dynamic.
1010 ``dim_prefix`` can be a string (the function uses it as a prefix),
1111 or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
12+ Depending on the version of transformers, the serialization functions
13+ of the DynamicCache class are automatically registered or not (>= 4.51, < 4.55).
1214
1315 .. runpython::
1416 :showcode:
@@ -17,6 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
1719 import torch
1820 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
1921 from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
22+ from onnx_diagnostic.torch_export_patches import torch_export_patches
2023
2124 bsize, nheads, slen, dim = 2, 1, 30, 96
2225 inputs = dict(
@@ -25,10 +28,11 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
2528 position_ids=torch.arange(3, dtype=torch.int64),
2629 past_key_values=make_dynamic_cache(
2730 [(torch.randn(bsize, nheads, slen, dim),
28- torch.randn(bsize, nheads, slen, dim))]
31+ torch.randn(bsize, nheads, slen, dim))]
2932 ),
3033 )
31- ds = all_dynamic_shape_from_inputs(inputs)
34+ with torch_export_patches(patch_transformers=True):
35+ ds = all_dynamic_shape_from_inputs(inputs)
3236 pprint.pprint(ds)
3337
3438 For this function to work, patches must be enabled if :epkg:`transformers`
0 commit comments