Skip to content

Commit 400fb24

Browse files
committed
fix conversion issues
1 parent 962993e commit 400fb24

File tree

6 files changed

+47
-7
lines changed

6 files changed

+47
-7
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Change Logs
55
+++++
66

77

8+
* :pr:`200`: fixes patches for transformers 4.55.1+: DynamicCache is no longer registered by default;
9+
this code moved to executorch.py in transformers
810
* :pr:`199`: delete hidden_size and num_attention_heads modification in a config
911
* :pr:`198`: support gpt-oss
1012
* :pr:`197`: updates CI for torch 2.8

_unittests/ut_helpers/test_bench_run.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
make_configs,
1111
run_benchmark,
1212
)
13-
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
13+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
1414

1515

1616
class TestBenchRun(ExtTestCase):
@@ -153,9 +153,10 @@ def test_max_diff(self):
153153
def test_max_diff_dynamic_cache(self):
154154
t1 = torch.tensor([0, 1], dtype=torch.float32)
155155
cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
156+
dc = CacheKeyValue(cache)
156157
md = max_diff(
157158
(t1, cache),
158-
(t1, cache.key_cache[0], cache.value_cache[0]),
159+
(t1, dc.key_cache[0], dc.value_cache[0]),
159160
flatten=True,
160161
verbose=10,
161162
)

_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
make_static_cache,
99
make_sliding_window_cache,
1010
flatten_unflatten_for_dynamic_shapes,
11+
CacheKeyValue,
1112
)
1213
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
1314
torch_export_patches,
@@ -48,7 +49,8 @@ def test_encoder_decoder_cache_deepcopy(self):
4849
def test_encoder_decoder_cache_export(self):
4950
class Model(torch.nn.Module):
5051
def forward(self, cache):
51-
return cache.self_attention_cache.key_cache[0]
52+
att = CacheKeyValue(cache.self_attention_cache)
53+
return att.key_cache[0]
5254

5355
cache1 = make_dynamic_cache(
5456
[(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
@@ -88,6 +90,7 @@ def test_dynamic_cache_flatten(self):
8890
def test_dynamic_cache_export(self):
8991
class Model(torch.nn.Module):
9092
def forward(self, cache):
93+
cache = CacheKeyValue(cache)
9194
return cache.key_cache[0]
9295

9396
cache = make_dynamic_cache(
@@ -180,7 +183,8 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
180183
def test_sliding_window_cache_export(self):
181184
class Model(torch.nn.Module):
182185
def forward(self, cache):
183-
return cache.key_cache[0]
186+
dc = CacheKeyValue(cache)
187+
return dc.key_cache[0]
184188

185189
cache = make_sliding_window_cache(
186190
[
@@ -268,6 +272,7 @@ def test_static_cache(self):
268272
# export
269273
class Model(torch.nn.Module):
270274
def forward(self, cache):
275+
cache = CacheKeyValue(cache)
271276
return cache.key_cache[0]
272277

273278
model = Model()

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def test_bypass_onnx_export_tiny_llm_official_full(self):
107107
self.assertEqual(
108108
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
109109
)
110+
print("***", self.string_type(inputs, with_shape=True))
111+
print("---", type(model))
110112
with torch_export_patches(
111113
patch_transformers=True, verbose=1, stop_if_static=1
112114
) as modificator:

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def _catch_produce_guards_and_solve_constraints(
2727
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
2828
equalities_inputs: "EqualityConstraint", # noqa: F821
2929
original_signature: inspect.Signature,
30-
_is_torch_jit_trace: bool = False,
3130
verbose: int = 0,
31+
**kwargs,
3232
):
3333
try:
3434
return previous_function(
@@ -37,7 +37,7 @@ def _catch_produce_guards_and_solve_constraints(
3737
dynamic_shapes=dynamic_shapes,
3838
equalities_inputs=equalities_inputs,
3939
original_signature=original_signature,
40-
_is_torch_jit_trace=_is_torch_jit_trace,
40+
**kwargs,
4141
)
4242
except Exception as e:
4343
if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
@@ -51,7 +51,7 @@ def _catch_produce_guards_and_solve_constraints(
5151
f"dynamic_shapes={dynamic_shapes}\n"
5252
f"equalities_inputs={equalities_inputs}\n"
5353
f"original_signature={original_signature}\n"
54-
f"_is_torch_jit_trace={_is_torch_jit_trace}\n"
54+
f"kwargs={kwargs}\n"
5555
f"exc={e}\ngm={gm}"
5656
)
5757
torch._dynamo.reset()

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@
2424
patch_masking_utils = False
2525

2626

27+
try:
28+
# transformers >= 4.55.1
29+
from transformers.cache_utils import DynamicLayer
30+
31+
patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
32+
except ImportError:
33+
patch_DynamicLayer = False
34+
2735
from ...ext_test_case import has_transformers
2836
from ...helpers.torch_helper import is_torchdynamo_exporting
2937

@@ -158,6 +166,20 @@ def patched_parse_processor_args(
158166
return processor_kwargs, remaining_kwargs
159167

160168

169+
if patch_DynamicLayer:
170+
171+
class patched_DynamicLayer:
172+
_PATCHES_ = ["lazy_initialization"]
173+
_PATCHED_CLASS_ = DynamicLayer
174+
175+
def lazy_initialization(self, key_states: torch.Tensor):
176+
self.dtype, self.device = key_states.dtype, key_states.device
177+
new_shape = list(key_states.shape)
178+
new_shape[-2] = 0
179+
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
180+
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
181+
182+
161183
def _patch_make_causal_mask(
162184
input_ids_shape: torch.Size,
163185
dtype: torch.dtype,
@@ -324,6 +346,14 @@ def update(
324346
self.key_cache[layer_idx] = key_states
325347
self.value_cache[layer_idx] = value_states
326348
else:
349+
torch._check(
350+
len(self.key_cache[layer_idx].shape) == len(key_states.shape),
351+
lambda: (
352+
f"Rank mismatch len(self.key_cache[layer_idx].shape)="
353+
f"{len(self.key_cache[layer_idx].shape)}, "
354+
f"len(key_states.shape)={len(key_states.shape)}"
355+
),
356+
)
327357
self.key_cache[layer_idx] = torch.cat(
328358
[self.key_cache[layer_idx], key_states], dim=-2
329359
)

0 commit comments

Comments
 (0)