Commit 28f3515

Changes Cache serialization
1 parent 7979496 commit 28f3515

File tree

4 files changed: +177 -117 lines changed

_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py

Lines changed: 14 additions & 9 deletions
@@ -8,6 +8,7 @@
     make_static_cache,
     make_sliding_window_cache,
     flatten_unflatten_for_dynamic_shapes,
+    make_dynamic_shapes_kv_cache,
     CacheKeyValue,
 )
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -64,8 +65,8 @@ def forward(self, cache):
         model(cache)
         DYN = torch.export.Dim.DYNAMIC
         ds = [
-            [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
-            [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
+            make_dynamic_shapes_kv_cache(cache1, {0: DYN}),
+            make_dynamic_shapes_kv_cache(cache2, {0: DYN}),
         ]

         with torch_export_patches(patch_transformers=True):
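For context, a quick sketch (illustration only, not part of the commit) of what this swap amounts to for a 3-layer cache: the hand-written nested dynamic-shape lists are replaced by one flat spec per key/value tensor, which is what the new helper produces.

# Sketch assuming a 3-layer DynamicCache, as in this test.
import torch

DYN = torch.export.Dim.DYNAMIC
n_layers = 3  # assumed layer count for illustration
# old, hand-written nested form: one sub-list for keys, one for values
ds_old = [[{0: DYN}] * n_layers, [{0: DYN}] * n_layers]
# new flat form returned by make_dynamic_shapes_kv_cache: 2 * n_layers entries
ds_new = [{0: DYN} for _ in range(2 * n_layers)]
assert len(ds_new) == 6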
@@ -99,9 +100,15 @@ def forward(self, cache):
         model = Model()
         model(cache)
         DYN = torch.export.Dim.DYNAMIC
-        ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]
+        ds = make_dynamic_shapes_kv_cache(cache, {0: DYN})
+        self.assertEqual(len(ds), 6)

-        with torch_export_patches():
+        with torch_export_patches(patch_transformers=True):
+            flat, _spec = torch.utils._pytree.tree_flatten(cache)
+            self.assertEqual(len(flat), len(ds))
+            unflat = torch.utils._pytree.tree_unflatten(flat, _spec)
+            if hasattr(unflat, "layers"):
+                self.assertEqual(len(unflat.layers), 3)
             torch.export.export(model, (cache,), dynamic_shapes=(ds,))

     @ignore_warnings(UserWarning)
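The new assertions lean on the pytree registration that the patches install for transformers caches. A minimal sketch of that round trip (shapes and the torch_export_patches import path are assumptions, not taken from the commit):

# Sketch only: assumes the patched pytree registration is active.
import torch
import transformers
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path

cache = transformers.cache_utils.DynamicCache()
for layer in range(3):
    cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer)

with torch_export_patches(patch_transformers=True):
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    # one leaf per key/value tensor: 2 * 3 layers = 6, matching len(ds) in the test
    print(len(flat))
    cache2 = torch.utils._pytree.tree_unflatten(flat, spec)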
@@ -195,7 +202,7 @@ def forward(self, cache):
         model = Model()
         model(cache)
         DYN = torch.export.Dim.DYNAMIC
-        ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]
+        ds = make_dynamic_shapes_kv_cache(cache, {0: DYN})

         with torch_export_patches(patch_transformers=True):
             torch.export.export(model, (cache,), dynamic_shapes=(ds,))
@@ -265,9 +272,7 @@ def test_static_cache(self):
         flat, _spec = torch.utils._pytree.tree_flatten(bo)
         unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
         self.assertIsInstance(unflat, list)
-        self.assertEqual(
-            "#2[#3[T1r4,T1r4,T1r4],#3[T1r4,T1r4,T1r4]]", self.string_type(unflat)
-        )
+        self.assertEqual("#6[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]", self.string_type(unflat))

         # export
         class Model(torch.nn.Module):
@@ -278,7 +283,7 @@ def forward(self, cache):
         model = Model()
         model(bo)
         DYN = torch.export.Dim.DYNAMIC
-        ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]
+        ds = make_dynamic_shapes_kv_cache(bo, {0: DYN})

         with torch_export_patches(patch_transformers=True, stop_if_static=1):
             torch.export.export(model, (bo,), dynamic_shapes=(ds,))

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 20 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 import packaging.version as pv
 import torch
 import transformers
@@ -46,9 +46,14 @@ def __init__(self, cache=None):
         raise NotImplementedError(f"type(cache)={type(cache)}")

     def make_dynamic_cache(self):
-        """Do the reverse operation."""
+        """Does the reverse operation."""
         return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))

+    @property
+    def n_layers(self) -> int:
+        """Returns the number of layers."""
+        return len(self.key_cache) if self.key_cache else 0
+

 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
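As a usage note, the new property just counts the entries of key_cache. A small sketch (cache contents chosen arbitrarily):

# Sketch: build a 2-layer cache and check the new property.
import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

cache = make_dynamic_cache(
    [(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)) for _ in range(2)]
)
kv = CacheKeyValue(cache)
assert kv.n_layers == len(kv.key_cache) == 2  # 0 for an empty cache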
@@ -134,6 +139,19 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
     return len(cache2.key_cache) == len(cache.value_cache)


+def make_dynamic_shapes_kv_cache(
+    cache: transformers.cache_utils.Cache, shape_of_one: Dict[str, Any]
+) -> List[Dict[int, Any]]:
+    """
+    Returns the dynamic shapes for key-value cache
+
+    :param cache: a cache
+    :param shape_of_one: shape of one element
+    :return: dynamic shapes
+    """
+    return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)]
+
+
 if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

     def make_dynamic_cache(
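A short usage sketch mirroring the updated tests (tensor shapes chosen arbitrarily): the helper simply repeats the per-tensor spec once per flattened key/value tensor.

import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_dynamic_shapes_kv_cache,
)

DYN = torch.export.Dim.DYNAMIC
cache = make_dynamic_cache(
    [(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)) for _ in range(3)]
)
ds = make_dynamic_shapes_kv_cache(cache, {0: DYN})
assert ds == [{0: DYN}] * 6  # 2 tensors per layer * 3 layers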

onnx_diagnostic/helpers/helper.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 import ast
 import enum
 import inspect
+import itertools
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
@@ -948,8 +949,8 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
         from .cache_helper import CacheKeyValue

         kc = CacheKeyValue(x)
-        res = flatten_object(kc.key_cache) + flatten_object(kc.value_cache)
-        return tuple(res)
+        return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache)))
+
     if x.__class__.__name__ == "EncoderDecoderCache":
         res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
         return tuple(res)
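The replacement changes the flattening order as well as the container type: keys and values are now interleaved per layer instead of being concatenated as two blocks. A tiny sketch of the difference (plain strings stand in for the cached tensors):

import itertools

key_cache = ["k0", "k1", "k2"]
value_cache = ["v0", "v1", "v2"]

# previous behaviour, roughly: all keys, then all values, as a tuple
old = tuple(key_cache + value_cache)  # ('k0', 'k1', 'k2', 'v0', 'v1', 'v2')
# new behaviour: key/value pairs interleaved per layer, as a list
new = list(itertools.chain.from_iterable(zip(key_cache, value_cache)))
print(new)  # ['k0', 'v0', 'k1', 'v1', 'k2', 'v2']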
