
Commit bfac0a1

fix caches
1 parent e89379d commit bfac0a1

File tree

4 files changed: +81 −10 lines changed


_unittests/ut_helpers/test_helper.py

Lines changed: 53 additions & 2 deletions
@@ -10,6 +10,7 @@
     skipif_ci_windows,
     hide_stdout,
     requires_onnx,
+    requires_transformers,
 )
 from onnx_diagnostic.helpers.helper import (
     string_type,
@@ -40,7 +41,13 @@
     onnx_dtype_to_torch_dtype,
     torch_dtype_to_onnx_dtype,
 )
-from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
+from onnx_diagnostic.helpers.cache_helper import (
+    make_dynamic_cache,
+    make_encoder_decoder_cache,
+    make_static_cache,
+    make_hybrid_cache,
+    make_sliding_window_cache,
+)
 from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config


@@ -584,11 +591,55 @@ def test_flatten_encoder_decoder_cache(self):
         s = string_type(inputs)
         self.assertIn("EncoderDecoderCache", s)
 
-    def test_string_typeçconfig(self):
+    def test_string_type_config(self):
         conf = get_pretrained_config("microsoft/phi-2", use_only_preinstalled=True)
         s = string_type(conf)
         self.assertStartsWith("PhiConfig(**{", s)
 
+    @requires_transformers("4.55")
+    def test_max_diff_causal_output(self):
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+
+        logits = torch.rand((3, 4))
+        cache = make_dynamic_cache([(torch.rand((3, 4)), torch.rand((3, 4)))])
+        out1 = CausalLMOutputWithPast(logits=logits, past_key_values=cache)
+        out2 = CausalLMOutputWithPast(logits=logits, past_key_values=cache)
+        self.assertEqual(max_diff(out1, out2)["abs"], 0)
+        self.assertEqual(
+            max_diff(out1, [logits, cache.layers[0].keys, cache.layers[0].values])["abs"], 0
+        )
+
+    def test_max_diff_others(self):
+        t = torch.rand((3, 4))
+        self.assertEqual(max_diff(t, t)["abs"], 0)
+        self.assertEqual(max_diff([t], [t])["abs"], 0)
+        self.assertEqual(max_diff([t], (t,))["abs"], 0)
+        self.assertEqual(max_diff((t,), [t])["abs"], 0)
+        self.assertEqual(max_diff((t,), (t,))["abs"], 0)
+        self.assertEqual(max_diff({"t": t}, {"t": t})["abs"], 0)
+
+    def test_max_diff_caches(self):
+        cache = make_dynamic_cache([(torch.rand((3, 4)), torch.rand((3, 4)))])
+        self.assertEqual(max_diff(cache, cache)["abs"], 0)
+        cache = make_static_cache(
+            [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))], max_cache_len=3
+        )
+        self.assertEqual(max_diff(cache, cache)["abs"], 0)
+        cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))])
+        self.assertEqual(max_diff(cache, cache)["abs"], 0)
+        cache = make_sliding_window_cache(
+            [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]
+        )
+        self.assertEqual(max_diff(cache, cache)["abs"], 0)
+        cache = make_encoder_decoder_cache(cache, cache)
+        self.assertEqual(max_diff(cache, cache)["abs"], 0)
+
+    def test_max_diff_caches_flat(self):
+        data = [(torch.rand((3, 4)), torch.rand((3, 4)))]
+        cache1 = make_dynamic_cache(data)
+        cache2 = make_dynamic_cache([*data[0]])
+        self.assertEqual(max_diff(cache1, cache2)["abs"], 0)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
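
The new tests rely on the cache helpers now accepting a flat list of tensors as well as a list of (key, value) pairs. A minimal sketch of the behavior test_max_diff_caches_flat checks, assuming onnx_diagnostic and torch are installed:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.helpers.helper import max_diff

    key, value = torch.rand((3, 4)), torch.rand((3, 4))
    # Paired form: a list of (key, value) tuples.
    cache_pairs = make_dynamic_cache([(key, value)])
    # Flat form: [key, value, ...]; the helper pairs the tensors internally.
    cache_flat = make_dynamic_cache([key, value])
    # Both forms should build the same cache.
    assert max_diff(cache_pairs, cache_flat)["abs"] == 0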

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ def test_validate_tiny_llms_bfloat16(self):
         self.assertIn("onnx_filename", data)
 
     @requires_transformers("4.53")
-    @requires_torch("2.7.99")
+    @requires_torch("2.8.99")
     @requires_experimental()
     @hide_stdout()
     def test_validate_microsoft_phi4_reasoning(self):
@@ -60,7 +60,7 @@ def test_validate_microsoft_phi4_reasoning(self):
         self.assertIn("onnx_filename", data)
 
     @requires_transformers("4.53")
-    @requires_torch("2.7.99")
+    @requires_torch("2.8.99")
     @requires_experimental()
     @hide_stdout()
     def test_validate_microsoft_phi3_mini_128k(self):

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 19 additions & 6 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
@@ -152,10 +152,18 @@ def make_dynamic_shapes_kv_cache(
     return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)]
 
 
+def _preprocess_key_value_pairs(
+    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+    if not key_value_pairs or isinstance(key_value_pairs[0], tuple):
+        return key_value_pairs
+    return list(zip(key_value_pairs[::2], key_value_pairs[1::2]))
+
+
 if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
 
     def make_dynamic_cache(
-        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -191,6 +199,7 @@ def make_dynamic_cache(
         ``transformers>=4.56``. Before that version, only FakeTensor with static dimensions
         are supported.
         """
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
         if (
             key_value_pairs
             and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
@@ -230,7 +239,7 @@ def make_dynamic_cache(
 else:
 
     def make_dynamic_cache(
-        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -262,14 +271,15 @@ def make_dynamic_cache(
             )
             print(string_type(past_key_values, with_shape=True))
         """
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
         cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
         for i, (key, value) in enumerate(key_value_pairs):
             cache.update(key, value, i)
         return cache
 
 
 def make_static_cache(
-    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
     max_cache_len: Optional[int] = None,
 ) -> transformers.cache_utils.DynamicCache:
     """
@@ -302,6 +312,7 @@ def make_static_cache(
         )
         print(string_type(past_key_values, with_shape=True))
     """
+    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
 
     class _config:
         def __init__(self):
@@ -444,9 +455,10 @@ def get_text_config(self, *args, **kwargs):
 
 
 def make_sliding_window_cache(
-    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
 ) -> transformers.cache_utils.SlidingWindowCache:
     "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
 
     class _config:
         def __init__(self):
@@ -499,7 +511,7 @@ def get_text_config(self, *args, **kwargs):
 
 
 def make_hybrid_cache(
-    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
     max_cache_len: Optional[int] = None,
     max_batch_size: Optional[int] = None,
     sliding_window: Optional[int] = None,
@@ -584,6 +596,7 @@ def make_hybrid_cache(
             self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
     """
+    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
     layer_types = None
     if key_value_pairs:
         assert (
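
The pairing itself is done by the new _preprocess_key_value_pairs helper: when the first element is not a tuple, even-indexed tensors are treated as keys and odd-indexed tensors as values. A self-contained sketch of that logic (pair_flat_list is a hypothetical stand-in for the private helper):

    import torch

    def pair_flat_list(tensors):
        # Pair even-indexed tensors (keys) with odd-indexed tensors (values),
        # mirroring _preprocess_key_value_pairs above.
        return list(zip(tensors[::2], tensors[1::2]))

    k0, v0, k1, v1 = (torch.rand((3, 4)) for _ in range(4))
    pairs = pair_flat_list([k0, v0, k1, v1])
    # The same tensor objects come back, grouped into (key, value) tuples.
    assert pairs[0][0] is k0 and pairs[0][1] is v0
    assert pairs[1][0] is k1 and pairs[1][1] is v1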

onnx_diagnostic/helpers/helper.py

Lines changed: 7 additions & 0 deletions
@@ -1064,6 +1064,13 @@ def max_diff(
                 f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
                 f"? {string_type(got)}"
             )
+        if got.__class__.__name__ == "CausalLMOutputWithPast":
+            return max_diff(
+                [expected.logits, *flatten_object(expected.past_key_values)],
+                [got.logits, *flatten_object(got.past_key_values)],
+                debug_info=_debug(expected.__class__.__name__),
+                **_dkws,
+            )
         return max_diff(
             [expected.logits, *flatten_object(expected.past_key_values)],
             got,
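
The pre-existing path below the new branch compares a CausalLMOutputWithPast against an already-flattened list; the added branch flattens both sides when the second argument is also a CausalLMOutputWithPast. A hedged sketch of the now-supported call, mirroring test_max_diff_causal_output (which guards it with transformers>=4.55):

    import torch
    from transformers.modeling_outputs import CausalLMOutputWithPast
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.helpers.helper import max_diff

    logits = torch.rand((3, 4))
    cache = make_dynamic_cache([(torch.rand((3, 4)), torch.rand((3, 4)))])
    out1 = CausalLMOutputWithPast(logits=logits, past_key_values=cache)
    out2 = CausalLMOutputWithPast(logits=logits, past_key_values=cache)
    # Object-vs-object comparison: both sides are flattened to
    # [logits, *cache tensors] before the difference is computed.
    assert max_diff(out1, out2)["abs"] == 0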
