
Commit f92a334

fix cache
1 parent 85156ef commit f92a334

3 files changed: +87 -75 lines changed

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 65 additions & 70 deletions
@@ -8,7 +8,6 @@
     ignore_warnings,
     hide_stdout,
     requires_torch,
-    has_transformers,
 )
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
@@ -22,76 +21,72 @@ class TestOnnxExportErrors(ExtTestCase):
     @ignore_warnings(UserWarning)
     @hide_stdout()
     def test_export_dynamic_cache_update(self):
-        values = [True, False] if has_transformers("4.50") else [False]
-        for strict in self.subloop(values, verbose=1):
-
-            class SubModelCache(torch.nn.Module):
-                def forward(self, cache):
-                    cc = CacheKeyValue(cache)
-                    # If not patched...
-                    # Fails with transformers>=4.54 because function ``parse_processor_args``
-                    # relies in inspect and the exporter is not very fond of that.
-                    # torch._dynamo.exc.Unsupported: id() with unsupported args
-                    # Explanation: Dynamo doesn't know how to trace id()
-                    # call with args
-                    # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
-                    # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
-                    # objects from outside the compiled region.
-                    # Hint: It may be possible to write Dynamo tracing rules for this code.
-                    d = cache.__class__()
-                    d.update(cc.key_cache[0] + 1, cc.value_cache[0] + 2, 0)
-                    d.update(cc.key_cache[0] + 3, cc.value_cache[0] + 5, 1)
-                    return d
-
-            class SubModel(torch.nn.Module):
-                def forward(self, x, cache):
-                    cc = CacheKeyValue(cache)
-                    return x + cc.key_cache[0] + cc.value_cache[0]
-
-            class Model(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.sub = SubModel()
-                    self.subcache = SubModelCache()
-
-                def forward(self, x, cache):
-                    return self.sub(x, self.subcache(cache))
-
-            # no patch
-            cache = make_dynamic_cache(
-                [(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)]
+        class SubModelCache(torch.nn.Module):
+            def forward(self, cache):
+                cc = CacheKeyValue(cache)
+                # If not patched...
+                # Fails with transformers>=4.54 because function ``parse_processor_args``
+                # relies in inspect and the exporter is not very fond of that.
+                # torch._dynamo.exc.Unsupported: id() with unsupported args
+                # Explanation: Dynamo doesn't know how to trace id()
+                # call with args
+                # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+                # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+                # objects from outside the compiled region.
+                # Hint: It may be possible to write Dynamo tracing rules for this code.
+                d = cache.__class__()
+                d.update(cc.key_cache[0] + 1, cc.value_cache[0] + 2, 0)
+                d.update(cc.key_cache[0] + 3, cc.value_cache[0] + 5, 1)
+                return d
+
+        class SubModel(torch.nn.Module):
+            def forward(self, x, cache):
+                cc = CacheKeyValue(cache)
+                y = cc.key_cache[0] + cc.value_cache[0]
+                return x + y
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.sub = SubModel()
+                self.subcache = SubModelCache()
+
+            def forward(self, x, cache):
+                return self.sub(x, self.subcache(cache))
+
+        # no patch
+        cache = make_dynamic_cache([(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)])
+        model = Model()
+        inputs = (torch.randn((5, 6, 5, 6)), cache)
+        expected = model(*inputs)
+
+        DYN = torch.export.Dim.DYNAMIC
+
+        # patching
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            got = model(*inputs)
+            self.assertEqualArray(expected, got)
+            ep = torch.export.export(
+                model,
+                inputs,
+                dynamic_shapes=(
+                    {0: DYN, 2: DYN},
+                    [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]],
+                ),
+                strict=False,
             )
-            model = Model()
-            inputs = (torch.randn((5, 6, 5, 6)), cache)
-            expected = model(*inputs)
-
-            DYN = torch.export.Dim.DYNAMIC
-
-            # patching
-            with torch_export_patches(patch_transformers=True, verbose=10):
-                got = model(*inputs)
-                self.assertEqualArray(expected, got)
-                ep = torch.export.export(
-                    model,
-                    inputs,
-                    dynamic_shapes=(
-                        {0: DYN, 2: DYN},
-                        [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]],
-                    ),
-                    strict=strict,
-                )
-                mod = ep.module()
-                got = mod(*inputs)
-                self.assertEqualArray(expected, got)
-
-                class MyInterpreter(torch.fx.Interpreter):
-                    def call_function(self, target, args, kwargs):
-                        res = super().call_function(target, args, kwargs)
-                        return res
-
-                args, _spec = torch.utils._pytree.tree_flatten(inputs)
-                got = MyInterpreter(ep.module()).run(*args)
-                self.assertEqualAny(expected, got)
+            mod = ep.module()
+            got = mod(*inputs)
+            self.assertEqualArray(expected, got)
+
+            class MyInterpreter(torch.fx.Interpreter):
+                def call_function(self, target, args, kwargs):
+                    res = super().call_function(target, args, kwargs)
+                    return res
+
+            args, _spec = torch.utils._pytree.tree_flatten(inputs)
+            got = MyInterpreter(ep.module()).run(*args)
+            self.assertEqualAny(expected, got)
 
     @ignore_warnings(UserWarning)
     @requires_torch(
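
To see the new code path outside the test harness, here is a condensed, standalone sketch of what the rewritten test now exercises, with the assertions and the FX interpreter check dropped. The import path of torch_export_patches is assumed (it is not visible in this diff); everything else mirrors the lines added above.

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path


class Model(torch.nn.Module):
    def forward(self, x, cache):
        cc = CacheKeyValue(cache)  # view over the key/value tensors of the cache
        return x + cc.key_cache[0] + cc.value_cache[0]


model = Model()
cache = make_dynamic_cache([(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)])
inputs = (torch.randn((5, 6, 5, 6)), cache)
DYN = torch.export.Dim.DYNAMIC

# The patches make the transformers cache traceable; the test now always exports with strict=False.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model,
        inputs,
        dynamic_shapes=({0: DYN, 2: DYN}, [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]]),
        strict=False,
    )
print(ep)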

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 21 additions & 4 deletions
@@ -183,7 +183,7 @@ def make_dynamic_cache(
                 f"Unexpected number of layers in the cache ({len(cache.layers)}), "
                 f"{len(key_value_pairs)} expected."
             )
-        return cache
+        return finalize_cache(cache)
 
 else:
 
@@ -335,7 +335,7 @@ def get_text_config(self):
             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
             f"{len(key_value_pairs)} expected."
         )
-    return cache
+    return finalize_cache(cache)
 
 
 def make_encoder_decoder_cache(
@@ -391,7 +391,7 @@ def get_text_config(self):
                 f"got {key_value_pairs[i][1].shape}"
             )
         cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
-    return cache
+    return finalize_cache(cache)
 
 
 def make_sliding_window_cache(
@@ -446,7 +446,7 @@ def get_text_config(self):
             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
             f"{len(key_value_pairs)} expected."
         )
-    return cache
+    return finalize_cache(cache)
 
 
 def make_hybrid_cache(
@@ -605,4 +605,21 @@ def get_text_config(self):
             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
             f"{len(key_value_pairs)} expected."
         )
+    return finalize_cache(cache)
+
+
+def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
+    """
+    Ensures the created cache is consistent.
+    Returns the cache modified inplace.
+    """
+    if (
+        hasattr(cache, "layer_class_to_replicate")
+        and hasattr(cache, "layers")
+        and cache.layers
+        and not cache.layer_class_to_replicate
+    ):
+        # This is used to expand the cache when it does not contains enough layers.
+        # This is needed since transformers>4.55.3
+        cache.layer_class_to_replicate = cache.layers[0].__class__
     return cache
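
A minimal sketch (not part of the commit) of what finalize_cache changes in practice, assuming transformers>4.55.3 where a cache keeps one object per layer in cache.layers: every make_* helper now also records which layer class to replicate, which, per the comment above, is what lets the cache expand when it does not yet contain enough layers.

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

# Build a one-layer cache; make_dynamic_cache now ends with finalize_cache(cache).
cache = make_dynamic_cache([(torch.ones((2, 4, 3, 8)), torch.zeros((2, 4, 3, 8)))])

# With this commit, a freshly built cache advertises the layer class to replicate
# when a later update targets a layer index that does not exist yet
# (assumes transformers>4.55.3, where this attribute exists).
print(getattr(cache, "layer_class_to_replicate", None))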

onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
             f"Inconsistencies in subset={subset}, found={values}, "
             f"it cannot be a {cls}, value={string_type(value)}"
         )
-    cache_length = len(value.key_cache)
+    cache_length = len(value.layers if hasattr(value, "layers") else value.key_cache)
     for v in subset.values():
         axes = v
         break
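
This one-line change accounts for the two cache layouts found across transformers versions: newer caches expose a layers list, while older DynamicCache objects only expose the parallel key_cache/value_cache lists. A small illustrative helper (hypothetical name, same logic as the patched line) makes the branch explicit:

def cache_num_layers(cache) -> int:
    # Newer transformers caches keep one object per layer in `cache.layers`;
    # older DynamicCache versions only expose parallel key/value tensor lists.
    return len(cache.layers if hasattr(cache, "layers") else cache.key_cache)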
