diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 4eb46bf3..ce0b259b 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,11 @@ Change Logs =========== +0.4.2 ++++++ + +* :pr:`73`: supports MambaCache in max_diff, torch_deepcopy + 0.4.1 +++++ diff --git a/_doc/index.rst b/_doc/index.rst index e31ce2ae..d6fc9c9a 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -173,6 +173,7 @@ Size of the package: Older versions ++++++++++++++ +* `0.4.2 <../v0.4.2/index.html>`_ * `0.4.1 <../v0.4.1/index.html>`_ * `0.4.0 <../v0.4.0/index.html>`_ * `0.3.0 <../v0.3.0/index.html>`_ diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index 0b752a0d..28db5f4c 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -145,8 +145,8 @@ def test_make_mamba_cache(self): ) text = self.string_type(cache, with_shape=True) self.assertEqual( - "MambaCache(conv_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4], " - "ssm_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4])", + "MambaCache(conv_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4], " + "ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])", text, ) diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py index f67e87c8..b6a5e86a 100644 --- a/_unittests/ut_helpers/test_torch_test_helper.py +++ b/_unittests/ut_helpers/test_torch_test_helper.py @@ -2,8 +2,9 @@ import ml_dtypes import onnx import torch +import transformers from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout -from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.helpers import max_diff, string_type from onnx_diagnostic.helpers.torch_test_helper import ( dummy_llm, to_numpy, @@ -13,7 +14,12 @@ to_any, torch_deepcopy, ) -from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache +from onnx_diagnostic.helpers.cache_helper import ( + make_dynamic_cache, + make_encoder_decoder_cache, + make_mamba_cache, + make_sliding_window_cache, +) TFLOAT = onnx.TensorProto.FLOAT @@ -85,12 +91,15 @@ def test_to_any(self): at = to_any(a, torch.float16) self.assertIn("T10r", string_type(at)) - def test_torch_deepcopy(self): + def test_torch_deepcopy_cache_dce(self): c1 = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) c2 = make_encoder_decoder_cache( make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), ) + cc = torch_deepcopy(c2) + self.assertEqual(type(c2), type(c2)) + self.assertEqual(max_diff(c2, cc)["abs"], 0) a = {"t": [(torch.tensor([1, 2]), c1, c2), {4, 5}]} at = torch_deepcopy(a) hash1 = string_type(at, with_shape=True, with_min_max=True) @@ -98,6 +107,53 @@ def test_torch_deepcopy(self): hash2 = string_type(at, with_shape=True, with_min_max=True) self.assertEqual(hash1, hash2) + def test_torch_deepcopy_mamba_cache(self): + cache = make_mamba_cache( + [ + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + ] + ) + at = torch_deepcopy(cache) + self.assertEqual(type(cache), type(at)) + self.assertEqual(max_diff(cache, at)["abs"], 0) + hash1 = string_type(at, with_shape=True, with_min_max=True) + cache.conv_states[0] += 1000 + hash2 = string_type(at, with_shape=True, with_min_max=True) + self.assertEqual(hash1, hash2) + + def test_torch_deepcopy_base_model_outputs(self): + bo = transformers.modeling_outputs.BaseModelOutput( + last_hidden_state=torch.rand((4, 4, 4)) + ) + at = torch_deepcopy(bo) + self.assertEqual(max_diff(bo, at)["abs"], 0) + self.assertEqual(type(bo), type(at)) + hash1 = string_type(at, with_shape=True, with_min_max=True) + bo.last_hidden_state[0] += 1000 + hash2 = string_type(at, with_shape=True, with_min_max=True) + self.assertEqual(hash1, hash2) + + def test_torch_deepcopy_sliding_windon_cache(self): + cache = make_sliding_window_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + at = torch_deepcopy(cache) + self.assertEqual(type(cache), type(at)) + self.assertEqual(max_diff(cache, at)["abs"], 0) + hash1 = string_type(at, with_shape=True, with_min_max=True) + cache.key_cache[0] += 1000 + hash2 = string_type(at, with_shape=True, with_min_max=True) + self.assertEqual(hash1, hash2) + + def test_torch_deepcopy_none(self): + self.assertEmpty(torch_deepcopy(None)) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index d56a8b55..68fe5b78 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.4.1" +__version__ = "0.4.2" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 2040e6ff..3f50ecf7 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -155,6 +155,7 @@ def make_mamba_cache( key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], ) -> transformers.cache_utils.MambaCache: "Creates a :class:`transformers.cache_utils.MambaCache`." + dtype = key_value_pairs[0][0].dtype class _config: def __init__(self): @@ -162,14 +163,23 @@ def __init__(self): self.conv_kernel = key_value_pairs[0][0].shape[-1] self.state_size = key_value_pairs[0][1].shape[-1] self.num_hidden_layers = len(key_value_pairs) - self.dtype = key_value_pairs[0][0].dtype + self.dtype = dtype cache = transformers.cache_utils.MambaCache( _config(), max_batch_size=key_value_pairs[0][0].shape[0], device=key_value_pairs[0][0].device, + dtype=dtype, ) for i in range(len(key_value_pairs)): + assert cache.conv_states[i].dtype == dtype, ( + f"Type mismatch for cache.conv_states[{i}].dtype=" + f"{cache.conv_states[i].dtype} != {dtype}" + ) + assert cache.ssm_states[i].dtype == dtype, ( + f"Type mismatch for cache.ssm_states[{i}].dtype=" + f"{cache.ssm_states[i].dtype} != {dtype}" + ) assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, ( f"Shape mismatch, expected {cache.conv_states[i].shape}, " f"got {key_value_pairs[i][0].shape}" diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index b61d0a48..b4ec9326 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1404,6 +1404,28 @@ def max_diff( f"level={level}" ) + if expected.__class__.__name__ == "SlidingWindowCache": + if got.__class__.__name__ == "SlidingWindowCache": + if verbose >= 6: + print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}") + return max_diff( + [expected.key_cache, expected.value_cache], + [got.key_cache, got.value_cache], + verbose=verbose, + ) + if isinstance(got, tuple) and len(got) == 2: + return max_diff( + [expected.key_cache, expected.value_cache], + [got[0], got[1]], + verbose=verbose, + ) + raise AssertionError( + f"SlidingWindowCache not fully implemented with classes " + f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, " + f"and expected={string_type(expected)}, got={string_type(got)},\n" + f"level={level}" + ) + if expected.__class__.__name__ == "EncoderDecoderCache": if got.__class__.__name__ == "EncoderDecoderCache": if verbose >= 6: diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py index e8ede458..2bc4f2e2 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ b/onnx_diagnostic/helpers/torch_test_helper.py @@ -8,6 +8,7 @@ make_dynamic_cache, make_encoder_decoder_cache, make_sliding_window_cache, + make_mamba_cache, ) @@ -346,6 +347,8 @@ def torch_deepcopy(value: Any) -> Any: """ Makes a deepcopy. """ + if value is None: + return None if isinstance(value, (int, float, str)): return value if isinstance(value, tuple): @@ -376,6 +379,9 @@ def torch_deepcopy(value: Any) -> Any: torch_deepcopy(value.self_attention_cache), torch_deepcopy(value.cross_attention_cache), ) + if value.__class__.__name__ == "MambaCache": + return make_mamba_cache(list(zip(value.conv_states, value.ssm_states))) + if value.__class__ in torch.utils._pytree.SUPPORTED_NODES: args, spec = torch.utils._pytree.tree_flatten(value) new_args = torch_deepcopy(args)