2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -16,7 +16,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python: ['3.11', '3.12']
transformers: ['4.48.3', '4.51.2', 'main']
transformers: ['4.48.3', '4.51.3', 'main']
torch: ['2.6', 'main']

steps:
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.4.0
+++++

* :pr:`65`: support SlidingWindowCache
* :pr:`63`: support option ``--trained``
* :pr:`61`: improves dynamic shapes for EncoderDecoderCache
* :pr:`58`: add function use_dyn_not_str to replace string by ``torch.export.Dim.DYNAMIC``,
2 changes: 2 additions & 0 deletions _doc/conf.py
@@ -123,6 +123,7 @@
("py:class", "transformers.cache_utils.DynamicCache"),
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
("py:class", "transformers.cache_utils.MambaCache"),
("py:class", "transformers.cache_utils.SlidingWindowCache"),
("py:class", "transformers.configuration_utils.PretrainedConfig"),
("py:func", "torch.export._draft_export.draft_export"),
("py:func", "torch._export.tools.report_exportability"),
@@ -187,6 +188,7 @@
"ExecuTorch": "https://pytorch.org/executorch/stable/intro-overview.html",
"ExecuTorch Runtime Python API Reference": "https://pytorch.org/executorch/stable/runtime-python-api-reference.html",
"ExecuTorch Tutorial": "https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html",
"experimental-experiment": "https://sdpython.github.io/doc/experimental-experiment/dev/",
"JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
"FunctionProto": "https://onnx.ai/onnx/api/classes.html#functionproto",
"graph break": "https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks",
37 changes: 35 additions & 2 deletions _unittests/ut_helpers/test_cache_helper.py
@@ -1,12 +1,14 @@
import unittest
import torch
import transformers
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
flatten_unflatten_for_dynamic_shapes,
make_dynamic_cache,
make_encoder_decoder_cache,
flatten_unflatten_for_dynamic_shapes,
make_mamba_cache,
make_sliding_window_cache,
)
from onnx_diagnostic.export import CoupleInputsDynamicShapes
from onnx_diagnostic.torch_export_patches.patch_inputs import (
@@ -132,6 +134,37 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
self.string_type(c2, with_shape=True),
)

@requires_transformers("4.51") # the structure changes
def test_make_mamba_cache(self):
cache = make_mamba_cache(
[
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
]
)
text = self.string_type(cache, with_shape=True)
self.assertEqual(
"MambaCache(conv_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4], "
"ssm_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4])",
text,
)

def test_make_sliding_window_cache(self):
cache = make_sliding_window_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
)
text = self.string_type(cache, with_shape=True)
self.assertEqual(
"SlidingWindowCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
"value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
text,
)


if __name__ == "__main__":
unittest.main(verbosity=2)
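
Note: the compact strings asserted above come from ``string_type``; ``T<k>s<dims>`` appears to encode the ONNX element type ``k`` (1 = float32, 10 = float16, hence ``T10`` for the float16 MambaCache states) followed by the tensor shape. A minimal sketch, not part of the diff, assuming an installed ``onnx_diagnostic``:

import torch
from onnx_diagnostic.helpers import string_type

t = torch.rand((4, 5, 6, 7))  # float32 by default
print(string_type(t, with_shape=True))                    # 'T1s4x5x6x7'
print(string_type(t.to(torch.float16), with_shape=True))  # 'T10s4x5x6x7'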
50 changes: 49 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_serialization.py
@@ -1,10 +1,11 @@
import unittest
import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch
from onnx_diagnostic.helpers.cache_helper import (
make_encoder_decoder_cache,
make_dynamic_cache,
make_sliding_window_cache,
flatten_unflatten_for_dynamic_shapes,
)
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -164,6 +165,53 @@ def test_base_model_output_unflatten_flatten(self):
self.assertIsInstance(unflat, dict)
self.assertEqual(list(unflat), ["last_hidden_state"])

@ignore_warnings(UserWarning)
def test_base_sliding_window_cache_unflatten_flatten(self):
cache = make_sliding_window_cache(
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
)
with bypass_export_some_errors():
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)

@ignore_warnings(UserWarning)
@requires_torch("2.7")
def test_sliding_window_cache_export(self):
class Model(torch.nn.Module):
def forward(self, cache):
return cache.key_cache[0]

cache = make_sliding_window_cache(
[
(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
]
)
model = Model()
model(cache)
DYN = torch.export.Dim.DYNAMIC
ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]

with bypass_export_some_errors(patch_transformers=True):
torch.export.export(model, (cache,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
def test_sliding_window_cache_flatten(self):
cache = make_sliding_window_cache(
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
)
with bypass_export_some_errors():
flat, _spec = torch.utils._pytree.tree_flatten(cache)
self.assertEqual(
"#2[T1s4x4x4x4,T1s4x4x4x4]",
self.string_type(flat, with_shape=True),
)
cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(cache, with_shape=True, with_min_max=True),
self.string_type(cache2, with_shape=True, with_min_max=True),
)


if __name__ == "__main__":
unittest.main(verbosity=2)
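
Note: the flatten/unflatten round-trip tested above relies on ``SlidingWindowCache`` being registered with ``torch.utils._pytree`` inside ``bypass_export_some_errors``; the contract is the standard pytree one. A standalone sketch, not part of the diff, using a plain dict, which is registered by default:

import torch

obj = {"key_cache": [torch.rand(2, 3)], "value_cache": [torch.rand(2, 3)]}
flat, spec = torch.utils._pytree.tree_flatten(obj)  # leaf tensors + spec
assert len(flat) == 2  # the two tensors are the leaves
obj2 = torch.utils._pytree.tree_unflatten(flat, spec)  # rebuilds the container
assert torch.equal(obj["key_cache"][0], obj2["key_cache"][0])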
25 changes: 25 additions & 0 deletions onnx_diagnostic/_command_lines_parser.py
@@ -1,3 +1,4 @@
import argparse
import json
import sys
import textwrap
@@ -227,6 +228,21 @@ def _cmd_config(argv: List[Any]):
print(f"task: {task_from_id(args.mid)}")


class _ParseDict(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
d = getattr(namespace, self.dest) or {}

if values:
for item in values:
split_items = item.split("=", 1)
key = split_items[0].strip()  # strip blanks around the key
value = split_items[1]

d[key] = value

setattr(namespace, self.dest, d)


def get_parser_validate() -> ArgumentParser:
parser = ArgumentParser(
prog="test",
@@ -297,6 +313,14 @@ def get_parser_validate() -> ArgumentParser:
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
parser.add_argument("--dtype", help="changes dtype if necessary")
parser.add_argument("--device", help="changes the device if necessary")
parser.add_argument(
"--iop",
metavar="KEY=VALUE",
nargs="*",
help="Additional input options, use to change the default "
"inputs use to export, example: --iop cls_cache=SlidingWindowCache",
action=_ParseDict,
)
return parser


@@ -346,6 +370,7 @@ def _cmd_validate(argv: List[Any]):
dump_folder=args.dump_folder,
drop_inputs=None if not args.drop else args.drop.split(","),
ortfusiontype=args.ortfusiontype,
input_options=args.iop,
)
print("")
print("-- summary --")
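
Note: a standalone sketch, not part of the diff, of the ``--iop KEY=VALUE`` pattern implemented by ``_ParseDict`` above, using ``str.partition`` instead of ``split`` but accumulating into a dict the same way:

import argparse

class ParseDict(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        d = getattr(namespace, self.dest) or {}
        for item in values or []:
            key, _, value = item.partition("=")  # split on the first '='
            d[key.strip()] = value  # blanks around the key are stripped
        setattr(namespace, self.dest, d)

parser = argparse.ArgumentParser()
parser.add_argument("--iop", metavar="KEY=VALUE", nargs="*", action=ParseDict)
args = parser.parse_args(["--iop", "cls_cache=SlidingWindowCache"])
print(args.iop)  # {'cls_cache': 'SlidingWindowCache'}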
2 changes: 1 addition & 1 deletion onnx_diagnostic/ext_test_case.py
@@ -920,7 +920,7 @@ def assertEqualAny(
else:
for e, g in zip(expected, value):
self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
elif expected.__class__.__name__ == "DynamicCache":
elif expected.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
self.assertEqual(type(expected), type(value), msg=msg)
atts = ["key_cache", "value_cache"]
self.assertEqualAny(
41 changes: 35 additions & 6 deletions onnx_diagnostic/helpers/cache_helper.py
@@ -26,12 +26,8 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
subtrees = []
for subspec in spec.children_specs:
end += subspec.num_leaves
if use_dict and (subspec.type is dict or subspec.context):
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
else:
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
subtrees.append(value)
start = end
if use_dict and (spec.type is dict or spec.context):
@@ -185,3 +181,36 @@ def __init__(self):
)
cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
return cache


def make_sliding_window_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.SlidingWindowCache:
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."

class _config:
def __init__(self):
self.head_dim = key_value_pairs[0][0].shape[-1]
self.num_attention_heads = key_value_pairs[0][0].shape[1]
self.num_hidden_layers = len(key_value_pairs)
self.sliding_window = key_value_pairs[0][0].shape[2]

cache = transformers.cache_utils.SlidingWindowCache(
_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
max_cache_len=key_value_pairs[0][0].shape[2], # same as sliding_window
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
)
for i in range(len(key_value_pairs)):
assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
f"Shape mismatch, expected {cache.key_cache[i].shape}, "
f"got {key_value_pairs[i][0].shape}"
)
cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
f"Shape mismatch, expected {cache.value_cache[i].shape}, "
f"got {key_value_pairs[i][1].shape}"
)
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
return cache
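
Note: a minimal usage sketch of the new helper, not part of the diff. The geometry ``[batch, num_attention_heads, sliding_window, head_dim]`` is inferred from the first pair, so every layer must share it:

import torch
from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache

pairs = [  # one (key, value) pair per hidden layer
    (torch.rand((2, 8, 16, 64)), torch.rand((2, 8, 16, 64))) for _ in range(3)
]
cache = make_sliding_window_cache(pairs)
assert len(cache.key_cache) == 3
assert cache.key_cache[0].shape == (2, 8, 16, 64)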
2 changes: 1 addition & 1 deletion onnx_diagnostic/helpers/helper.py
@@ -534,7 +534,7 @@ def string_type(
print(f"[string_type] CACHE1:{type(obj)}")
return f"MambaCache(conv_states={c}, ssm_states={d})"

if obj.__class__.__name__ == "DynamicCache":
if obj.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
kc = string_type(
obj.key_cache,
with_shape=with_shape,
10 changes: 9 additions & 1 deletion onnx_diagnostic/helpers/torch_test_helper.py
@@ -4,7 +4,11 @@
import numpy as np
import torch
from .helper import string_type
from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache
from .cache_helper import (
make_dynamic_cache,
make_encoder_decoder_cache,
make_sliding_window_cache,
)


def _forward_(*args, _f=None, _context=None, **kwargs):
@@ -363,6 +367,10 @@ def torch_deepcopy(value: Any) -> Any:
return make_dynamic_cache(
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
)
if value.__class__.__name__ == "SlidingWindowCache":
return make_sliding_window_cache(
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
)
if value.__class__.__name__ == "EncoderDecoderCache":
return make_encoder_decoder_cache(
torch_deepcopy(value.self_attention_cache),
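
Note: with the new branch above, ``torch_deepcopy`` rebuilds a ``SlidingWindowCache`` from the re-paired key/value lists. A sketch, not part of the diff, assuming the leaf tensors are copied into fresh storage as the dispatch suggests:

import torch
from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache
from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy

cache = make_sliding_window_cache([(torch.rand(1, 2, 3, 4), torch.rand(1, 2, 3, 4))])
copy = torch_deepcopy(cache)
assert torch.equal(copy.key_cache[0], cache.key_cache[0])  # same values
assert copy.key_cache[0].data_ptr() != cache.key_cache[0].data_ptr()  # no shared storage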
20 changes: 18 additions & 2 deletions onnx_diagnostic/tasks/text_generation.py
@@ -1,6 +1,11 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache
import transformers
from ..helpers.cache_helper import (
make_dynamic_cache,
make_mamba_cache,
make_sliding_window_cache,
)
from ..helpers.config_helper import update_config, check_hasattr, _pick

__TASK__ = "text-generation"
@@ -88,6 +93,10 @@ def get_inputs(
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)

if config is not None and config.__class__.__name__ == "FalconMambaConfig":
assert cls_cache in (
"MambaCache",
transformers.cache_utils.MambaCache,
), f"Unexpected value for cls_cache={cls_cache} and config={config}"
seq_length_multiple = 8
sequence_length = (
(sequence_length + seq_length_multiple)
@@ -156,6 +165,13 @@ def get_inputs(
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
],
}

make_cache = (
make_sliding_window_cache
if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
else make_dynamic_cache
)

inputs = dict(
input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
torch.int64
@@ -166,7 +182,7 @@
position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
.to(torch.int64)
.expand((batch_size, -1)),
past_key_values=make_dynamic_cache(
past_key_values=make_cache(
[
(
torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),