Skip to content

Commit ec103a4

Browse files
committed
Support MambaCache in torch_deepcopy
1 parent 8ce2aaa commit ec103a4

File tree

4 files changed

+83
-4
lines changed

4 files changed

+83
-4
lines changed

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import ml_dtypes
33
import onnx
44
import torch
5+
import transformers
56
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
6-
from onnx_diagnostic.helpers import string_type
7+
from onnx_diagnostic.helpers import max_diff, string_type
78
from onnx_diagnostic.helpers.torch_test_helper import (
89
dummy_llm,
910
to_numpy,
@@ -13,7 +14,12 @@
1314
to_any,
1415
torch_deepcopy,
1516
)
16-
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
17+
from onnx_diagnostic.helpers.cache_helper import (
18+
make_dynamic_cache,
19+
make_encoder_decoder_cache,
20+
make_mamba_cache,
21+
make_sliding_window_cache,
22+
)
1723

1824
TFLOAT = onnx.TensorProto.FLOAT
1925

@@ -85,19 +91,66 @@ def test_to_any(self):
8591
at = to_any(a, torch.float16)
8692
self.assertIn("T10r", string_type(at))
8793

88-
def test_torch_deepcopy(self):
94+
def test_torch_deepcopy_cache_dce(self):
8995
c1 = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
9096
c2 = make_encoder_decoder_cache(
9197
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
9298
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
9399
)
100+
cc = torch_deepcopy(c2)
101+
self.assertEqual(type(c2), type(c2))
102+
self.assertEqual(max_diff(c2, cc)["abs"], 0)
94103
a = {"t": [(torch.tensor([1, 2]), c1, c2), {4, 5}]}
95104
at = torch_deepcopy(a)
96105
hash1 = string_type(at, with_shape=True, with_min_max=True)
97106
c1.key_cache[0] += 1000
98107
hash2 = string_type(at, with_shape=True, with_min_max=True)
99108
self.assertEqual(hash1, hash2)
100109

110+
def test_torch_deepcopy_mamba_cache(self):
111+
cache = make_mamba_cache(
112+
[
113+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
114+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
115+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
116+
]
117+
)
118+
at = torch_deepcopy(cache)
119+
self.assertEqual(type(cache), type(at))
120+
self.assertEqual(max_diff(cache, at)["abs"], 0)
121+
hash1 = string_type(at, with_shape=True, with_min_max=True)
122+
cache.conv_states[0] += 1000
123+
hash2 = string_type(at, with_shape=True, with_min_max=True)
124+
self.assertEqual(hash1, hash2)
125+
126+
def test_torch_deepcopy_base_model_outputs(self):
127+
bo = transformers.modeling_outputs.BaseModelOutput(
128+
last_hidden_state=torch.rand((4, 4, 4))
129+
)
130+
at = torch_deepcopy(bo)
131+
self.assertEqual(max_diff(bo, at)["abs"], 0)
132+
self.assertEqual(type(bo), type(at))
133+
hash1 = string_type(at, with_shape=True, with_min_max=True)
134+
bo.last_hidden_state[0] += 1000
135+
hash2 = string_type(at, with_shape=True, with_min_max=True)
136+
self.assertEqual(hash1, hash2)
137+
138+
def test_torch_deepcopy_sliding_windon_cache(self):
139+
cache = make_sliding_window_cache(
140+
[
141+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
142+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
143+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
144+
]
145+
)
146+
at = torch_deepcopy(cache)
147+
self.assertEqual(type(cache), type(at))
148+
self.assertEqual(max_diff(cache, at)["abs"], 0)
149+
hash1 = string_type(at, with_shape=True, with_min_max=True)
150+
cache.key_cache[0] += 1000
151+
hash2 = string_type(at, with_shape=True, with_min_max=True)
152+
self.assertEqual(hash1, hash2)
153+
101154

102155
if __name__ == "__main__":
103156
unittest.main(verbosity=2)

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.4.1"
6+
__version__ = "0.4.2"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/helpers/helper.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,6 +1404,28 @@ def max_diff(
14041404
f"level={level}"
14051405
)
14061406

1407+
if expected.__class__.__name__ == "SlidingWindowCache":
1408+
if got.__class__.__name__ == "SlidingWindowCache":
1409+
if verbose >= 6:
1410+
print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
1411+
return max_diff(
1412+
[expected.key_cache, expected.value_cache],
1413+
[got.key_cache, got.value_cache],
1414+
verbose=verbose,
1415+
)
1416+
if isinstance(got, tuple) and len(got) == 2:
1417+
return max_diff(
1418+
[expected.key_cache, expected.value_cache],
1419+
[got[0], got[1]],
1420+
verbose=verbose,
1421+
)
1422+
raise AssertionError(
1423+
f"SlidingWindowCache not fully implemented with classes "
1424+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
1425+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
1426+
f"level={level}"
1427+
)
1428+
14071429
if expected.__class__.__name__ == "EncoderDecoderCache":
14081430
if got.__class__.__name__ == "EncoderDecoderCache":
14091431
if verbose >= 6:

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
make_dynamic_cache,
99
make_encoder_decoder_cache,
1010
make_sliding_window_cache,
11+
make_mamba_cache,
1112
)
1213

1314

@@ -376,6 +377,9 @@ def torch_deepcopy(value: Any) -> Any:
376377
torch_deepcopy(value.self_attention_cache),
377378
torch_deepcopy(value.cross_attention_cache),
378379
)
380+
if value.__class__.__name__ == "MambaCache":
381+
return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
382+
379383
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
380384
args, spec = torch.utils._pytree.tree_flatten(value)
381385
new_args = torch_deepcopy(args)

0 commit comments

Comments
 (0)