Skip to content

Commit 8cb8ec5

Browse files
committed
fix registration issue
1 parent d1b8ba7 commit 8cb8ec5

File tree

2 files changed

+216
-21
lines changed

2 files changed

+216
-21
lines changed

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 203 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import copy
2+
import os
23
import unittest
4+
from typing import Any, Dict, List, Tuple
35
import torch
4-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
6+
from onnx_diagnostic.ext_test_case import (
7+
ExtTestCase,
8+
ignore_warnings,
9+
hide_stdout,
10+
requires_torch,
11+
)
512
from onnx_diagnostic.helpers import string_type
613
from onnx_diagnostic.cache_helpers import make_dynamic_cache
714
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -45,21 +52,12 @@ def forward(self, x, cache):
4552
expected = model(*inputs)
4653

4754
DYN = torch.export.Dim.DYNAMIC
48-
ep = torch.export.export(
49-
model,
50-
inputs,
51-
dynamic_shapes=({0: DYN, 2: DYN}, [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]]),
52-
strict=strict,
53-
)
54-
mod = ep.module()
55-
got = mod(*inputs)
56-
self.assertEqualArray(expected, got)
5755

5856
# patching
5957
with bypass_export_some_errors(patch_transformers=True):
6058
got = model(*inputs)
6159
self.assertEqualArray(expected, got)
62-
ep2 = torch.export.export(
60+
ep = torch.export.export(
6361
model,
6462
inputs,
6563
dynamic_shapes=(
@@ -68,11 +66,201 @@ def forward(self, x, cache):
6866
),
6967
strict=strict,
7068
)
71-
mod = ep2.module()
69+
mod = ep.module()
7270
got = mod(*inputs)
7371
self.assertEqualArray(expected, got)
7472

73+
class MyInterpreter(torch.fx.Interpreter):
74+
def call_function(self, target, args, kwargs):
75+
res = super().call_function(target, args, kwargs)
76+
return res
77+
78+
args, _spec = torch.utils._pytree.tree_flatten(inputs)
79+
got = MyInterpreter(ep.module()).run(*args)
80+
self.assertEqualAny(expected, got)
81+
82+
@ignore_warnings(UserWarning)
83+
def test_export_mycache_list_cat(self):
84+
TreeContext = torch.utils._pytree.Context
85+
MappingKey = torch.utils._pytree.MappingKey
86+
KeyEntry = torch.utils._pytree.KeyEntry
87+
88+
class MyCache77:
89+
def __init__(self, key=None, value=None):
90+
self.key_cache = [key] if key is not None else []
91+
self.value_cache = [value] if value is not None else []
92+
93+
class ModelMyCache(torch.nn.Module):
94+
def forward(self, x, dc):
95+
y = (
96+
(
97+
torch.cat(dc.key_cache, axis=1) + torch.cat(dc.value_cache, axis=1)
98+
).reshape((-1, x.shape[1]))
99+
).transpose(1, 0)
100+
return x @ y
101+
102+
inputs = {
103+
"x": torch.randn(3, 8),
104+
"dc": MyCache77(torch.ones((3, 8, 3, 8)), torch.ones((3, 8, 3, 8))),
105+
}
106+
model = ModelMyCache()
107+
expected = model(**inputs)
108+
109+
def flatten_my_cache77(cache: MyCache77) -> Tuple[List[Any], TreeContext]:
110+
flat = [
111+
(k, getattr(cache, k))
112+
for k in ["key_cache", "value_cache"]
113+
if hasattr(cache, k)
114+
]
115+
return [f[1] for f in flat], [f[0] for f in flat]
116+
117+
def flatten_with_keys_my_cache77(
118+
d: Dict[Any, Any],
119+
) -> Tuple[List[Tuple[KeyEntry, Any]], TreeContext]:
120+
values, context = flatten_my_cache77(d)
121+
return [(MappingKey(k), v) for k, v in zip(context, values)], context
122+
123+
def unflatten_my_cache_77(
124+
values: List[Any], context: TreeContext, output_type=None
125+
) -> MyCache77:
126+
cache = MyCache77()
127+
values = dict(zip(context, values))
128+
for k, v in values.items():
129+
setattr(cache, k, v)
130+
return cache
131+
132+
torch.utils._pytree.register_pytree_node(
133+
MyCache77,
134+
flatten_my_cache77,
135+
unflatten_my_cache_77,
136+
serialized_type_name="MyCache77",
137+
flatten_with_keys_fn=flatten_with_keys_my_cache77,
138+
)
139+
140+
# DYN = torch.export.Dim.DYNAMIC
141+
ep = torch.export.export(model, (), kwargs=inputs)
142+
143+
args, _spec = torch.utils._pytree.tree_flatten(inputs)
144+
got = torch.fx.Interpreter(ep.module()).run(*args)
145+
self.assertEqualAny(expected, got)
146+
147+
mod = ep.module()
148+
got = mod(**inputs)
149+
self.assertEqualArray(expected, got)
150+
151+
@ignore_warnings(UserWarning)
152+
def test_export_mycache_dict_cat(self):
153+
TreeContext = torch.utils._pytree.Context
154+
155+
class MyCache78:
156+
def __init__(self, key=None, value=None):
157+
self.key_cache = [key] if key is not None else []
158+
self.value_cache = [value] if value is not None else []
159+
160+
class ModelMyCache(torch.nn.Module):
161+
def forward(self, x, dc):
162+
y = (
163+
(
164+
torch.cat(dc.key_cache, axis=1) + torch.cat(dc.value_cache, axis=1)
165+
).reshape((-1, x.shape[1]))
166+
).transpose(1, 0)
167+
return x @ y
168+
169+
inputs = {
170+
"x": torch.randn(3, 8),
171+
"dc": MyCache78(torch.ones((3, 8, 3, 8)), torch.ones((3, 8, 3, 8))),
172+
}
173+
model = ModelMyCache()
174+
expected = model(**inputs)
175+
176+
def flatten_my_cache78(cache: MyCache78):
177+
dictionary = {
178+
"key_cache": cache.key_cache,
179+
"value_cache": cache.value_cache,
180+
}
181+
return torch.utils._pytree._dict_flatten(dictionary)
182+
183+
def flatten_with_keys_my_cache78(cache: MyCache78):
184+
dictionary = {
185+
"key_cache": cache.key_cache,
186+
"value_cache": cache.value_cache,
187+
}
188+
return torch.utils._pytree._dict_flatten_with_keys(dictionary)
189+
190+
def unflatten_my_cache_78(values, context: TreeContext, output_type=None) -> MyCache78:
191+
dictionary = torch.utils._pytree._dict_unflatten(values, context)
192+
cache = MyCache78()
193+
for k, v in dictionary.items():
194+
setattr(cache, k, v)
195+
return cache
196+
197+
torch.utils._pytree.register_pytree_node(
198+
MyCache78,
199+
flatten_my_cache78,
200+
unflatten_my_cache_78,
201+
serialized_type_name="MyCache78",
202+
flatten_with_keys_fn=flatten_with_keys_my_cache78,
203+
)
204+
205+
# DYN = torch.export.Dim.DYNAMIC
206+
ep = torch.export.export(model, (), kwargs=inputs)
207+
208+
args, _spec = torch.utils._pytree.tree_flatten(inputs)
209+
got = torch.fx.Interpreter(ep.module()).run(*args)
210+
self.assertEqualAny(expected, got)
211+
212+
mod = ep.module()
213+
got = mod(**inputs)
214+
self.assertEqualArray(expected, got)
215+
75216
@ignore_warnings(UserWarning)
217+
def test_export_dynamic_cache_cat(self):
218+
219+
class ModelDynamicCache(torch.nn.Module):
220+
def forward(self, x, dc):
221+
y = (
222+
(
223+
torch.cat(dc.key_cache, axis=1) + torch.cat(dc.value_cache, axis=1)
224+
).reshape((-1, x.shape[1]))
225+
).transpose(1, 0)
226+
return x @ y
227+
228+
inputs = {
229+
"x": torch.randn(3, 8),
230+
"dc": make_dynamic_cache(
231+
[(torch.ones((3, 8, 3, 8)), (torch.ones((3, 8, 3, 8)) * 2))]
232+
),
233+
}
234+
model = ModelDynamicCache()
235+
expected = model(**inputs)
236+
237+
# DYN = torch.export.Dim.DYNAMIC
238+
NOBYPASS = int(os.environ.get("NOBYBASS", "0"))
239+
if NOBYPASS:
240+
ep = torch.export.export(model, (), kwargs=inputs)
241+
242+
args, _spec = torch.utils._pytree.tree_flatten(inputs)
243+
got = torch.fx.Interpreter(ep.module()).run(*args)
244+
self.assertEqualAny(expected, got)
245+
246+
mod = ep.module()
247+
got = mod(**inputs)
248+
self.assertEqualArray(expected, got)
249+
return
250+
251+
with bypass_export_some_errors(patch_transformers=True):
252+
ep = torch.export.export(model, (), kwargs=inputs)
253+
254+
args, _spec = torch.utils._pytree.tree_flatten(inputs)
255+
got = torch.fx.Interpreter(ep.module()).run(*args)
256+
self.assertEqualAny(expected, got)
257+
258+
mod = ep.module()
259+
got = mod(**inputs)
260+
self.assertEqualArray(expected, got)
261+
262+
@ignore_warnings(UserWarning)
263+
@requires_torch("2.9")
76264
def test_phi2_export_module(self):
77265
data = get_untrained_model_with_inputs("microsoft/phi-2")
78266
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -100,6 +288,7 @@ def test_phi2_export_module(self):
100288
dynamic_shapes=dyn_shapes,
101289
strict=False, # True works but then the it fails during the execution
102290
)
291+
# ep = ep.run_decompositions()
103292
mod = ep.module()
104293
inputs_copied = copy.deepcopy(inputs)
105294
self.assertEqual(
@@ -108,15 +297,8 @@ def test_phi2_export_module(self):
108297
got = mod(**inputs_copied)
109298
self.assertEqualAny(expected, got)
110299

111-
inputs_copied = copy.deepcopy(inputs)
112-
self.assertEqual(
113-
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
114-
)
115-
mod = ep.module()
116-
got = mod(**inputs_copied)
117-
self.assertEqualAny(expected, got)
118-
119300
@ignore_warnings(UserWarning)
301+
@requires_torch("2.9")
120302
def test_phi2_export_interpreter(self):
121303
data = get_untrained_model_with_inputs("microsoft/phi-2")
122304
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -144,6 +326,7 @@ def test_phi2_export_interpreter(self):
144326
dynamic_shapes=dyn_shapes,
145327
strict=False, # True works but then the it fails during the execution
146328
)
329+
# ep = ep.run_decompositions()
147330

148331
# from experimental_experiment.torch_interpreter.tracing import CustomTracer
149332
# CustomTracer.remove_unnecessary_slices(ep.graph)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,19 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
100100
flatten_with_keys_fn=flatten_with_keys_mamba_cache,
101101
)
102102

103-
# DynamicCache
103+
# DynamicCache serialization is different in transformers and does not
104+
# play way with torch.export.export.
105+
# see test test_export_dynamic_cache_cat with NOBYPASS=1
106+
# :: NOBYBASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c
107+
# This is caused by this line:
108+
# torch.fx._pytree.register_pytree_flatten_spec(
109+
# DynamicCache, _flatten_dynamic_cache_for_fx)
110+
# so we remove it anyway
111+
if DynamicCache in torch.fx._pytree.SUPPORTED_NODES:
112+
if verbose:
113+
print("[_register_cache_serialization] DynamicCache is unregistered first.")
114+
_unregister(DynamicCache)
115+
104116
unregistered_dynamic_cache = True
105117
if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
106118
if verbose > 1:

0 commit comments

Comments
 (0)