diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 72605848..74be09f0 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.7.2
 +++++
 
+* :pr:`166`: improves handling of StaticCache
 * :pr:`165`: support for task text-to-image
 * :pr:`162`: improves graphs rendering for historical data
 
diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py
index 56ead0ae..c0dab99f 100644
--- a/_unittests/ut_helpers/test_cache_helper.py
+++ b/_unittests/ut_helpers/test_cache_helper.py
@@ -175,12 +175,13 @@ def test_make_static_cache(self):
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-            ]
+            ],
+            max_cache_len=15,
         )
         text = self.string_type(cache, with_shape=True)
         self.assertEqual(
-            "StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
-            "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
+            "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+            "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
             text,
         )
         self.assertEqual(0, max_diff(cache, cache)["abs"])
@@ -192,7 +193,8 @@ def test_unflatten_flatten_static_cache(self):
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-            ]
+            ],
+            max_cache_len=6,
         )
         self.assertEqual(0, max_diff(c2, c2)["abs"])
         self.assertIsInstance(c2, transformers.cache_utils.StaticCache)
diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py
index fc7beaa4..9469cefd 100644
--- a/_unittests/ut_torch_export_patches/test_patch_serialization.py
+++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py
@@ -10,6 +10,7 @@ from onnx_diagnostic.helpers.cache_helper import (
     make_encoder_decoder_cache,
     make_dynamic_cache,
+    make_static_cache,
     make_sliding_window_cache,
     flatten_unflatten_for_dynamic_shapes,
 )
 
@@ -180,7 +181,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
         self.assertEqualAny([cache], cache2)
 
     @ignore_warnings(UserWarning)
-    @requires_torch("2.8")
+    @requires_torch("2.7.99")
     def test_sliding_window_cache_export(self):
         class Model(torch.nn.Module):
             def forward(self, cache):
@@ -274,6 +275,69 @@ def forward(self, cache):
         with torch_export_patches():
             torch.export.export(model, (bo,), dynamic_shapes=(ds,))
 
+    @ignore_warnings(UserWarning)
+    @requires_torch("2.7.99")
+    def test_static_cache(self):
+        bo = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ],
+            max_cache_len=15,
+        )
+        self.assertEqual(bo.__class__.__name__, "StaticCache")
+        bo2 = torch_deepcopy([bo])
+        self.assertIsInstance(bo2, list)
+        self.assertEqual(
+            "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+            "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
+            self.string_type(bo, with_shape=True),
+        )
+
+        with torch_export_patches():
+            # internal function
+            bo2 = torch_deepcopy([bo])
+            self.assertIsInstance(bo2, list)
+            self.assertEqual(bo2[0].__class__.__name__, "StaticCache")
+            self.assertEqualAny([bo], bo2)
+            self.assertEqual(
+                "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+                "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
+                self.string_type(bo, with_shape=True),
+            )
+
+            # serialization
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            self.assertEqual(
+                "#6[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7]",
+                self.string_type(flat, with_shape=True),
+            )
+            bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(bo, with_shape=True, with_min_max=True),
+                self.string_type(bo2, with_shape=True, with_min_max=True),
+            )
+
+            # flatten_unflatten
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
+            self.assertIsInstance(unflat, dict)
+            self.assertEqual(list(unflat), ["key_cache", "value_cache"])
+
+        # export
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.key_cache[0]
+
+        model = Model()
+        model(bo)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]
+
+        with torch_export_patches(patch_transformers=True, stop_if_static=1):
+            torch.export.export(model, (bo,), dynamic_shapes=(ds,))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py
index 04780181..99c2465a 100644
--- a/_unittests/ut_torch_export_patches/test_patch_torch.py
+++ b/_unittests/ut_torch_export_patches/test_patch_torch.py
@@ -204,6 +204,18 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
         ep = torch.export.export(Model(), inputs, dynamic_shapes=ds)
         self.assertEqualArray(causal_mask, ep.module()(*inputs))
 
+    @requires_torch("2.7")
+    def test_export_unsqueeze(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return x.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+
+        x = torch.tensor([7.0, 8.0])
+        Model()(x)
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
+        self.assertEqualArray(Model()(x), ep.module()(x))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py
index f65f6ac8..242b0ed4 100644
--- a/onnx_diagnostic/ext_test_case.py
+++ b/onnx_diagnostic/ext_test_case.py
@@ -976,6 +976,16 @@ def assertEqualAny(
                 atol=atol,
                 rtol=rtol,
             )
+        elif expected.__class__.__name__ == "StaticCache":
+            self.assertEqual(type(expected), type(value), msg=msg)
+            self.assertEqual(expected.max_cache_len, value.max_cache_len)
+            atts = ["key_cache", "value_cache"]
+            self.assertEqualAny(
+                {k: expected.__dict__.get(k, None) for k in atts},
+                {k: value.__dict__.get(k, None) for k in atts},
+                atol=atol,
+                rtol=rtol,
+            )
         elif expected.__class__.__name__ == "EncoderDecoderCache":
             self.assertEqual(type(expected), type(value), msg=msg)
             atts = ["self_attention_cache", "cross_attention_cache"]
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index ae4556fd..759fce17 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -154,10 +154,12 @@ def make_dynamic_cache(
 def make_static_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    max_cache_len: Optional[int] = None,
 ) -> transformers.cache_utils.DynamicCache:
     """
     Creates an instance of :class:`transformers.cache_utils.StaticCache`.
 
     :param key_value_pairs: list of pairs of (key, values)
+    :param max_cache_len: maximum cache length, it must be specified as it cannot yet be inferred from the tensor shapes
     :return: :class:`transformers.cache_utils.StaticCache`
 
     Example:
@@ -190,24 +192,32 @@ def __init__(self):
                 self.num_attention_heads = key_value_pairs[0][0].shape[1]
                 self.num_hidden_layers = len(key_value_pairs)
 
+    assert max_cache_len is not None, (
+        f"max_cache_len={max_cache_len} cannot be inferred "
+        f"automatically yet from shape {key_value_pairs[0][0].shape}"
+    )
+    torch._check(
+        max_cache_len >= key_value_pairs[0][0].shape[2],
+        (
+            f"max_cache_len={max_cache_len} cannot be smaller than "
+            f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
+            f"{key_value_pairs[0][0].shape}"
+        ),
+    )
     cache = transformers.cache_utils.StaticCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
-        max_cache_len=key_value_pairs[0][0].shape[2],
+        max_cache_len=max_cache_len,
     )
     for i in range(len(key_value_pairs)):
-        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
-            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
-            f"got {key_value_pairs[i][0].shape}"
-        )
-        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
-        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
-            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
-            f"got {key_value_pairs[i][1].shape}"
-        )
-        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+        assert (
+            key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
+        ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
+        d = key_value_pairs[i][1].shape[2]
+        cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+        cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
     return cache
 
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index d4437fc4..be20beb4 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -735,7 +735,8 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
                     [t.to(to_value) for t in value.key_cache],
                     [t.to(to_value) for t in value.value_cache],
                 )
-            )
+            ),
+            max_cache_len=value.max_cache_len,
         )
     if value.__class__.__name__ == "EncoderDecoderCache":
         return make_encoder_decoder_cache(
@@ -784,7 +785,10 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
         )
     if value.__class__.__name__ == "StaticCache":
-        return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
+        return make_static_cache(
+            torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
+            max_cache_len=value.max_cache_len,
+        )
     if value.__class__.__name__ == "SlidingWindowCache":
         return make_sliding_window_cache(
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py
index aee5f2bb..a960505b 100644
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -109,7 +109,7 @@ def get_inputs(
         sequence_length2 = seq_length_multiple
 
     shapes = {
-        "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "input_ids": {0: batch, 1: "sequence_length"},
         "attention_mask": {
             0: batch,
             1: "cache+seq",  # cache_length + seq_length
@@ -188,18 +188,25 @@ def get_inputs(
                 (batch_size, num_key_value_heads, sequence_length2, head_dim)
             ).to(torch.bool),
             cache_position=torch.arange(sequence_length2).to(torch.int64),
-            past_key_values=make_cache(
+            past_key_values=make_static_cache(
                 [
                     (
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
-                ]
+                ],
+                max_cache_len=max(sequence_length + sequence_length2, head_dim),
             ),
         )
     else:
@@ -230,7 +237,7 @@ def get_inputs(
         position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
         .to(torch.int64)
         .expand((batch_size, -1)),
-        past_key_values=make_cache(
+        past_key_values=make_cache(  # type: ignore[operator]
             [
                 (
                     torch.randn(
diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py
index 32a7415c..9557c15f 100644
--- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py
@@ -151,6 +151,11 @@ def flatten_static_cache(
     cache: StaticCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    assert not cache.key_cache or cache.max_cache_len == cache.key_cache[0].shape[2], (
+        f"Serialization does not work when "
+        f"cache.max_cache_len={cache.max_cache_len} != "
+        f"cache.key_cache[0].shape[2]={cache.key_cache[0].shape[2]}"
+    )
     flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
     return [f[1] for f in flat], [f[0] for f in flat]
 
@@ -167,7 +172,9 @@ def unflatten_static_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> StaticCache:
     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
-    return make_static_cache(list(zip(values[0], values[1])))
+    return make_static_cache(
+        list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2]
+    )
 
 
 ####################
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index a650feb3..017b1985 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -20,18 +20,41 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
     ]
     if bh_indices:
         dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+    # reshape
     dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
     dimensions = tuple(reversed(dimensions))
     indices = tuple(shape.index(-1) for shape in dimensions)
+    # unsqueeze
+    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
+
     def vector_mask_function(
         *args, mask_function=mask_function, dimensions=dimensions, indices=indices
     ):
-        assert len(args) == len(
-            dimensions
-        ), f"Mismatch between args={string_type(args)} and dimensions={dimensions}"
+        assert len(args) == len(dimensions) == len(udimensions), (
+            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+            f"and udimensions={udimensions}."
+        )
+        assert len(indices) == len(args), (
+            f"Mismatch between args={string_type(args)} and indices={indices}, "
+            f"they should have the same length."
+        )
+        for a in args:
+            assert (
+                a.ndim == 1
+            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+            torch._check(a.shape[0] > 0)
         new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+        # new_args = [
+        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+        #     for a, dims in zip(args, udimensions)
+        # ]
         max_shape = tuple(args[i].shape[0] for i in indices)
+        # if is_torchdynamo_exporting():
+        #     for a in args:
+        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
+        #         torch._check(a.shape[0] > 1)
         expanded_args = [a.expand(max_shape) for a in new_args]
         return mask_function(*expanded_args)
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index 03011e4b..6736359f 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -346,9 +346,10 @@ def validate_model(
         exported model returns the same outputs as the original one, otherwise,
         :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
     """
-    assert (
-        not rewrite or patch
-    ), f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting"
+    assert not rewrite or patch, (
+        f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting; "
+        f"if --no-patch was specified on the command line, --no-rewrite must be added as well."
+    )
     summary = version_summary()
     summary.update(
         dict(
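Usage sketch: a minimal illustration of the new `make_static_cache` signature introduced by this patch, assuming the patched `onnx_diagnostic` and `transformers` are installed; the shapes and `max_cache_len` value below are illustrative only, not taken from the patch.

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    # Three layers of (key, value) pairs; the sequence axis (dim 2) is 6.
    pairs = [(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 6, 7)) for _ in range(3)]

    # max_cache_len is now mandatory and must be >= the sequence axis;
    # the buffers are allocated for max_cache_len positions and the
    # provided tensors fill the first 6 of them.
    cache = make_static_cache(pairs, max_cache_len=15)
    print(cache.key_cache[0].shape)  # torch.Size([4, 5, 15, 7])

Note that per the patched `flatten_static_cache`, pytree serialization still requires `max_cache_len == key_cache[0].shape[2]`, so a cache built as above round-trips through `torch.utils._pytree` only when the tensors already fill the whole buffer.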