1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.7.2
+++++

* :pr:`166`: improves handling of StaticCache
* :pr:`165`: support for task text-to-image
* :pr:`162`: improves graphs rendering for historical data

10 changes: 6 additions & 4 deletions _unittests/ut_helpers/test_cache_helper.py
@@ -175,12 +175,13 @@ def test_make_static_cache(self):
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
],
max_cache_len=15,
)
text = self.string_type(cache, with_shape=True)
self.assertEqual(
"StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
"value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
"StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
"value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
text,
)
self.assertEqual(0, max_diff(cache, cache)["abs"])
@@ -192,7 +193,8 @@ def test_unflatten_flatten_static_cache(self):
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
],
max_cache_len=6,
)
self.assertEqual(0, max_diff(c2, c2)["abs"])
self.assertIsInstance(c2, transformers.cache_utils.StaticCache)
66 changes: 65 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_serialization.py
@@ -10,6 +10,7 @@
from onnx_diagnostic.helpers.cache_helper import (
make_encoder_decoder_cache,
make_dynamic_cache,
make_static_cache,
make_sliding_window_cache,
flatten_unflatten_for_dynamic_shapes,
)
@@ -180,7 +181,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
self.assertEqualAny([cache], cache2)

@ignore_warnings(UserWarning)
@requires_torch("2.8")
@requires_torch("2.7.99")
def test_sliding_window_cache_export(self):
class Model(torch.nn.Module):
def forward(self, cache):
@@ -274,6 +275,69 @@ def forward(self, cache):
with torch_export_patches():
torch.export.export(model, (bo,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
@requires_torch("2.7.99")
def test_static_cache(self):
bo = make_static_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
],
max_cache_len=15,
)
self.assertEqual(bo.__class__.__name__, "StaticCache")
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(
"StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
"value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
self.string_type(bo, with_shape=True),
)

with torch_export_patches():
# internal function
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(bo2[0].__class__.__name__, "StaticCache")
self.assertEqualAny([bo], bo2)
self.assertEqual(
"StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
"value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
self.string_type(bo, with_shape=True),
)

# serialization
flat, _spec = torch.utils._pytree.tree_flatten(bo)
self.assertEqual(
"#6[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7]",
self.string_type(flat, with_shape=True),
)
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(bo, with_shape=True, with_min_max=True),
self.string_type(bo2, with_shape=True, with_min_max=True),
)

# flatten_unflatten
flat, _spec = torch.utils._pytree.tree_flatten(bo)
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
self.assertIsInstance(unflat, dict)
self.assertEqual(list(unflat), ["key_cache", "value_cache"])

# export
class Model(torch.nn.Module):
def forward(self, cache):
return cache.key_cache[0]

model = Model()
model(bo)
DYN = torch.export.Dim.DYNAMIC
ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]

with torch_export_patches(patch_transformers=True, stop_if_static=1):
torch.export.export(model, (bo,), dynamic_shapes=(ds,))


if __name__ == "__main__":
unittest.main(verbosity=2)
12 changes: 12 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
@@ -204,6 +204,18 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
ep = torch.export.export(Model(), inputs, dynamic_shapes=ds)
self.assertEqualArray(causal_mask, ep.module()(*inputs))

@requires_torch("2.7")
def test_export_unsqueeze(self):
class Model(torch.nn.Module):
def forward(self, x):
return x.unsqueeze(0).unsqueeze(2).unsqueeze(3)

x = torch.tensor([7.0, 8.0])
Model()(x)
DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
self.assertEqualArray(Model()(x), ep.module()(x))


if __name__ == "__main__":
unittest.main(verbosity=2)
10 changes: 10 additions & 0 deletions onnx_diagnostic/ext_test_case.py
@@ -976,6 +976,16 @@ def assertEqualAny(
atol=atol,
rtol=rtol,
)
elif expected.__class__.__name__ == "StaticCache":
self.assertEqual(type(expected), type(value), msg=msg)
self.assertEqual(expected.max_cache_len, value.max_cache_len)
atts = ["key_cache", "value_cache"]
self.assertEqualAny(
{k: expected.__dict__.get(k, None) for k in atts},
{k: value.__dict__.get(k, None) for k in atts},
atol=atol,
rtol=rtol,
)
elif expected.__class__.__name__ == "EncoderDecoderCache":
self.assertEqual(type(expected), type(value), msg=msg)
atts = ["self_attention_cache", "cross_attention_cache"]
32 changes: 21 additions & 11 deletions onnx_diagnostic/helpers/cache_helper.py
@@ -154,10 +154,12 @@ def make_dynamic_cache(

def make_static_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
max_cache_len: Optional[int] = None,
) -> transformers.cache_utils.StaticCache:
"""
Creates an instance of :class:`transformers.cache_utils.StaticCache`.
:param key_value_pairs: list of pairs of (key, values)
:param max_cache_len: maximum cache length; it is not inferred from the tensors yet and must be provided
:return: :class:`transformers.cache_utils.StaticCache`

Example:
@@ -190,24 +192,32 @@ def __init__(self):
self.num_attention_heads = key_value_pairs[0][0].shape[1]
self.num_hidden_layers = len(key_value_pairs)

assert max_cache_len is not None, (
f"max_cache_len={max_cache_len} cannot be setup "
f"automatically yet from shape {key_value_pairs[0][0].shape}"
)
torch._check(
max_cache_len >= key_value_pairs[0][0].shape[2],
(
f"max_cache_len={max_cache_len} cannot be smaller "
f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
f"{key_value_pairs[0][0].shape}"
),
)
cache = transformers.cache_utils.StaticCache(
_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
max_cache_len=key_value_pairs[0][0].shape[2],
max_cache_len=max_cache_len,
)
for i in range(len(key_value_pairs)):
assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
f"Shape mismatch, expected {cache.key_cache[i].shape}, "
f"got {key_value_pairs[i][0].shape}"
)
cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
f"Shape mismatch, expected {cache.value_cache[i].shape}, "
f"got {key_value_pairs[i][1].shape}"
)
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
assert (
key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
d = key_value_pairs[i][1].shape[2]
cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
return cache


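A minimal usage sketch of the updated helper, assuming the import path onnx_diagnostic.helpers.cache_helper used by the tests above: the cache is allocated with max_cache_len slots along the sequence axis, and the supplied key/value pairs only fill the leading positions.

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache

# three layers of (key, value) pairs with sequence length 6
pairs = [(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 6, 7)) for _ in range(3)]

# the cache reserves 15 positions; the 6 provided ones are copied in,
# the remaining 9 keep the zeros StaticCache allocates
cache = make_static_cache(pairs, max_cache_len=15)

print(type(cache).__name__)      # StaticCache
print(cache.key_cache[0].shape)  # torch.Size([4, 5, 15, 7])
print(cache.max_cache_len)       # 15
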
8 changes: 6 additions & 2 deletions onnx_diagnostic/helpers/torch_helper.py
@@ -735,7 +735,8 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
[t.to(to_value) for t in value.key_cache],
[t.to(to_value) for t in value.value_cache],
)
)
),
max_cache_len=value.max_cache_len,
)
if value.__class__.__name__ == "EncoderDecoderCache":
return make_encoder_decoder_cache(
@@ -784,7 +785,10 @@ def torch_deepcopy(value: Any) -> Any:
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
)
if value.__class__.__name__ == "StaticCache":
return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
return make_static_cache(
torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
max_cache_len=value.max_cache_len,
)
if value.__class__.__name__ == "SlidingWindowCache":
return make_sliding_window_cache(
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
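A short hedged check of what the extra argument carries over, assuming torch_deepcopy is importable from onnx_diagnostic.helpers.torch_helper, the file patched above:

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

cache = make_static_cache(
    [(torch.rand(1, 2, 3, 4), torch.rand(1, 2, 3, 4))], max_cache_len=8
)
copy = torch_deepcopy(cache)
# the copy is rebuilt through make_static_cache and now explicitly
# receives the original max_cache_len instead of a value derived from shapes
assert copy.max_cache_len == cache.max_cache_len == 8
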
19 changes: 13 additions & 6 deletions onnx_diagnostic/tasks/text_generation.py
@@ -109,7 +109,7 @@ def get_inputs(
sequence_length2 = seq_length_multiple

shapes = {
"input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
"input_ids": {0: batch, 1: "sequence_length"},
"attention_mask": {
0: batch,
1: "cache+seq", # cache_length + seq_length
@@ -188,18 +188,25 @@ def get_inputs(
(batch_size, num_key_value_heads, sequence_length2, head_dim)
).to(torch.bool),
cache_position=torch.arange(sequence_length2).to(torch.int64),
past_key_values=make_cache(
past_key_values=make_static_cache(
[
(
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
batch_size,
num_key_value_heads,
sequence_length + sequence_length2,
head_dim,
),
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
batch_size,
num_key_value_heads,
sequence_length + sequence_length2,
head_dim,
),
)
for i in range(num_hidden_layers)
]
],
max_cache_len=max(sequence_length + sequence_length2, head_dim),
),
)
else:
@@ -230,7 +237,7 @@ def get_inputs(
position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
.to(torch.int64)
.expand((batch_size, -1)),
past_key_values=make_cache(
past_key_values=make_cache( # type: ignore[operator]
[
(
torch.randn(
@@ -151,6 +151,11 @@ def flatten_static_cache(
cache: StaticCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
"""Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
assert not cache.key_cache or cache.max_cache_len == cache.key_cache[0].shape[2], (
f"Serialization doet not work when "
f"cache.max_cache_len={cache.max_cache_len} != "
f"cache.key_cache[0].shape[2]={cache.key_cache[0].shape[2]}"
)
flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
return [f[1] for f in flat], [f[0] for f in flat]

@@ -167,7 +172,9 @@ def unflatten_static_cache(
values: List[Any], context: torch.utils._pytree.Context, output_type=None
) -> StaticCache:
"""Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
return make_static_cache(list(zip(values[0], values[1])))
return make_static_cache(
list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2]
)


####################
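A hedged sketch of the round trip the new assert protects, reusing the helpers named in this diff; the import path for torch_export_patches is assumed from the tests above, and the pytree functions are only registered while the patch context is active:

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache
from onnx_diagnostic.torch_export_patches import torch_export_patches

pairs = [(torch.rand(2, 3, 6, 8), torch.rand(2, 3, 6, 8)) for _ in range(2)]
# max_cache_len equal to shape[2] is the only configuration the flatten assert
# accepts, because max_cache_len is recovered from that axis when unflattening
cache = make_static_cache(pairs, max_cache_len=6)

with torch_export_patches():
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    # 2 layers x (key, value) -> 4 tensors of shape (2, 3, 6, 8)
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
    assert restored.max_cache_len == cache.max_cache_len
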
29 changes: 26 additions & 3 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -20,18 +20,41 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
]
if bh_indices:
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
# reshape
dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
dimensions = tuple(reversed(dimensions))
indices = tuple(shape.index(-1) for shape in dimensions)

# unsqueeze
udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]

def vector_mask_function(
*args, mask_function=mask_function, dimensions=dimensions, indices=indices
):
assert len(args) == len(
dimensions
), f"Mismatch between args={string_type(args)} and dimensions={dimensions}"
assert len(args) == len(dimensions) == len(udimensions), (
f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
f"and udimensions={udimensions}."
)
assert len(indices) == len(args), (
f"Mismatch between args={string_type(args)} and indices={indices}, "
f"they should have the same length."
)
for a in args:
assert (
a.ndim == 1
), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
torch._check(a.shape[0] > 0)

new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
# new_args = [
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
# for a, dims in zip(args, udimensions)
# ]
max_shape = tuple(args[i].shape[0] for i in indices)
# if is_torchdynamo_exporting():
# for a in args:
# # The exporter should export with a dimension > 1 to make sure it is dynamic.
# torch._check(a.shape[0] > 1)
expanded_args = [a.expand(max_shape) for a in new_args]
return mask_function(*expanded_args)

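The patched helper avoids torch.vmap by broadcasting the 1-D index tensors by hand; below is a standalone sketch of that reshape-and-expand idea, with a hypothetical causal mask function standing in for the real one and under the assumption that every argument is a 1-D tensor:

import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    # hypothetical mask function: position q_idx may attend to kv_idx <= q_idx
    return kv_idx <= q_idx

def vector_mask(batch_idx, head_idx, q_idx, kv_idx):
    # same idea as the patch: each 1-D argument is reshaped so its values lie
    # on their own axis, then everything is expanded to a common 4-D shape
    dims = [(-1, 1, 1, 1), (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)]
    args = [a.reshape(d) for a, d in zip((batch_idx, head_idx, q_idx, kv_idx), dims)]
    target = (batch_idx.shape[0], head_idx.shape[0], q_idx.shape[0], kv_idx.shape[0])
    return causal(*[a.expand(target) for a in args])

mask = vector_mask(torch.arange(2), torch.arange(4), torch.arange(3), torch.arange(5))
print(mask.shape)  # torch.Size([2, 4, 3, 5])
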
7 changes: 4 additions & 3 deletions onnx_diagnostic/torch_models/validate.py
@@ -346,9 +346,10 @@ def validate_model(
exported model returns the same outputs as the original one, otherwise,
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
"""
assert (
not rewrite or patch
), f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting"
assert not rewrite or patch, (
f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting, "
f"if --no-patch was specified on the command line, --no-rewrite must be added."
)
summary = version_summary()
summary.update(
dict(
Expand Down