
Commit 11512cb

1 parent 35661f6 commit 11512cb

File tree

10 files changed, +273 -55 lines changed

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 33 additions & 1 deletion
@@ -4,9 +4,11 @@
 from onnx_diagnostic.ext_test_case import ExtTestCase
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import (
+    flatten_unflatten_for_dynamic_shapes,
     make_dynamic_cache,
     make_encoder_decoder_cache,
-    flatten_unflatten_for_dynamic_shapes,
+    make_mamba_cache,
+    make_sliding_window_cache,
 )
 from onnx_diagnostic.export import CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches.patch_inputs import (
@@ -132,6 +134,36 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
             self.string_type(c2, with_shape=True),
         )
 
+    def test_make_mamba_cache(self):
+        cache = make_mamba_cache(
+            [
+                (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+            ]
+        )
+        text = self.string_type(cache, with_shape=True)
+        self.assertEqual(
+            "MambaCache(conv_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4], "
+            "ssm_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4])",
+            text,
+        )
+
+    def test_make_sliding_window_cache(self):
+        cache = make_sliding_window_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ]
+        )
+        text = self.string_type(cache, with_shape=True)
+        self.assertEqual(
+            "SlidingWindowCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
+            "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
+            text,
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 47 additions & 0 deletions
@@ -5,6 +5,7 @@
 from onnx_diagnostic.helpers.cache_helper import (
     make_encoder_decoder_cache,
     make_dynamic_cache,
+    make_sliding_window_cache,
     flatten_unflatten_for_dynamic_shapes,
 )
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -164,6 +165,52 @@ def test_base_model_output_unflatten_flatten(self):
         self.assertIsInstance(unflat, dict)
         self.assertEqual(list(unflat), ["last_hidden_state"])
 
+    @ignore_warnings(UserWarning)
+    def test_base_sliding_window_cache_unflatten_flatten(self):
+        cache = make_sliding_window_cache(
+            [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
+        )
+        with bypass_export_some_errors():
+            cache2 = torch_deepcopy([cache])
+            self.assertEqualAny([cache], cache2)
+
+    @ignore_warnings(UserWarning)
+    def test_sliding_window_cache_export(self):
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.key_cache[0]
+
+        cache = make_sliding_window_cache(
+            [
+                (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
+                (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
+            ]
+        )
+        model = Model()
+        model(cache)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]
+
+        with bypass_export_some_errors(patch_transformers=True):
+            torch.export.export(model, (cache,), dynamic_shapes=(ds,))
+
+    @ignore_warnings(UserWarning)
+    def test_sliding_window_cache_flatten(self):
+        cache = make_sliding_window_cache(
+            [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
+        )
+        with bypass_export_some_errors():
+            flat, _spec = torch.utils._pytree.tree_flatten(cache)
+            self.assertEqual(
+                "#2[T1s4x4x4x4,T1s4x4x4x4]",
+                self.string_type(flat, with_shape=True),
+            )
+            cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(cache, with_shape=True, with_min_max=True),
+                self.string_type(cache2, with_shape=True, with_min_max=True),
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 25 additions & 0 deletions
@@ -1,3 +1,4 @@
+import argparse
 import json
 import sys
 import textwrap
@@ -227,6 +228,21 @@ def _cmd_config(argv: List[Any]):
     print(f"task: {task_from_id(args.mid)}")
 
 
+class _ParseDict(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        d = getattr(namespace, self.dest) or {}
+
+        if values:
+            for item in values:
+                split_items = item.split("=", 1)
+                key = split_items[0].strip()  # strip whitespace around the key
+                value = split_items[1]
+
+                d[key] = value
+
+        setattr(namespace, self.dest, d)
+
+
 def get_parser_validate() -> ArgumentParser:
     parser = ArgumentParser(
         prog="test",
@@ -297,6 +313,14 @@ def get_parser_validate() -> ArgumentParser:
     parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
     parser.add_argument("--dtype", help="changes dtype if necessary")
    parser.add_argument("--device", help="changes the device if necessary")
+    parser.add_argument(
+        "--iop",
+        metavar="KEY=VALUE",
+        nargs="*",
+        help="Additional input options, used to change the default "
+        "inputs used for export, example: --iop cls_cache=SlidingWindowCache",
+        action=_ParseDict,
+    )
     return parser
 
 
@@ -346,6 +370,7 @@ def _cmd_validate(argv: List[Any]):
         dump_folder=args.dump_folder,
         drop_inputs=None if not args.drop else args.drop.split(","),
         ortfusiontype=args.ortfusiontype,
+        input_options=args.iop,
     )
     print("")
     print("-- summary --")

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion
@@ -920,7 +920,7 @@ def assertEqualAny(
             else:
                 for e, g in zip(expected, value):
                     self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
-        elif expected.__class__.__name__ == "DynamicCache":
+        elif expected.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
             self.assertEqual(type(expected), type(value), msg=msg)
             atts = ["key_cache", "value_cache"]
             self.assertEqualAny(

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 35 additions & 6 deletions
@@ -26,12 +26,8 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
     subtrees = []
     for subspec in spec.children_specs:
         end += subspec.num_leaves
-        if use_dict and (subspec.type is dict or subspec.context):
-            value = subspec.unflatten(flat[start:end])
-            value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
-        else:
-            value = subspec.unflatten(flat[start:end])
-            value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
+        value = subspec.unflatten(flat[start:end])
+        value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
         subtrees.append(value)
         start = end
     if use_dict and (spec.type is dict or spec.context):
@@ -185,3 +181,36 @@ def __init__(self):
         )
         cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
     return cache
+
+
+def make_sliding_window_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> transformers.cache_utils.SlidingWindowCache:
+    "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_value_pairs[0][0].shape[-1]
+            self.num_attention_heads = key_value_pairs[0][0].shape[1]
+            self.num_hidden_layers = len(key_value_pairs)
+            self.sliding_window = key_value_pairs[0][0].shape[2]
+
+    cache = transformers.cache_utils.SlidingWindowCache(
+        _config(),
+        max_batch_size=key_value_pairs[0][0].shape[0],
+        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
+        device=key_value_pairs[0][0].device,
+        dtype=key_value_pairs[0][0].dtype,
+    )
+    for i in range(len(key_value_pairs)):
+        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
+            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+            f"got {key_value_pairs[i][0].shape}"
+        )
+        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
+            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+            f"got {key_value_pairs[i][1].shape}"
+        )
+        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    return cache
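
A minimal usage sketch of the new helper, reusing the single-pair shape from the serialization test above; each pair is (key, value) with layout (batch, num_heads, sliding_window, head_dim), matching what the internal _config reads from the first tensor.

import torch
from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache

# One layer of (key, value) tensors shaped (batch, num_heads, sliding_window, head_dim).
cache = make_sliding_window_cache(
    [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
)
assert len(cache.key_cache) == 1
assert cache.key_cache[0].shape == (4, 4, 4, 4)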

onnx_diagnostic/helpers/helper.py

Lines changed: 1 addition & 1 deletion
@@ -534,7 +534,7 @@ def string_type(
         print(f"[string_type] CACHE1:{type(obj)}")
         return f"MambaCache(conv_states={c}, ssm_states={d})"
 
-    if obj.__class__.__name__ == "DynamicCache":
+    if obj.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
         kc = string_type(
             obj.key_cache,
             with_shape=with_shape,
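
With this change, string_type renders a SlidingWindowCache with the same notation as a DynamicCache. A short sketch, with the expected output copied from the unit test above:

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache

cache = make_sliding_window_cache(
    [(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))) for _ in range(3)]
)
print(string_type(cache, with_shape=True))
# SlidingWindowCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],
# value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])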

onnx_diagnostic/tasks/text_generation.py

Lines changed: 18 additions & 2 deletions
@@ -1,6 +1,11 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache
+import transformers
+from ..helpers.cache_helper import (
+    make_dynamic_cache,
+    make_mamba_cache,
+    make_sliding_window_cache,
+)
 from ..helpers.config_helper import update_config, check_hasattr, _pick
 
 __TASK__ = "text-generation"
@@ -88,6 +93,10 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        assert cls_cache in (
+            "MambaCache",
+            transformers.cache_utils.MambaCache,
+        ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
         sequence_length = (
             (sequence_length + seq_length_multiple)
@@ -156,6 +165,13 @@ def get_inputs(
             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
         ],
     }
+
+    make_cache = (
+        make_sliding_window_cache
+        if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
+        else make_dynamic_cache
+    )
+
     inputs = dict(
         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
             torch.int64
@@ -166,7 +182,7 @@ def get_inputs(
         position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
         .to(torch.int64)
         .expand((batch_size, -1)),
-        past_key_values=make_dynamic_cache(
+        past_key_values=make_cache(
            [
                (
                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
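
A minimal sketch of the cache-factory selection introduced above, wrapped here in a hypothetical helper name (pick_cache_factory is not part of the module); cls_cache may arrive as a string, for instance via --iop cls_cache=SlidingWindowCache, or as the class itself:

import transformers
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_sliding_window_cache,
)

def pick_cache_factory(cls_cache):
    # Mirrors the make_cache expression in get_inputs: default to a dynamic
    # cache unless a SlidingWindowCache is explicitly requested.
    if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache):
        return make_sliding_window_cache
    return make_dynamic_cache

make_cache = pick_cache_factory("SlidingWindowCache")  # -> make_sliding_window_cache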
