Skip to content

Commit 6137993

Browse files
committed
fix issues
1 parent cf64fb0 commit 6137993

File tree

6 files changed

+145
-133
lines changed

6 files changed

+145
-133
lines changed

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,42 @@ def test_get_untrained_model_with_inputs_automatic_speech_recognition(self):
112112
data = get_untrained_model_with_inputs(mid, verbose=1)
113113
self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
114114
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
115+
Dim = torch.export.Dim
116+
self.maxDiff = None
117+
self.assertIn("{0:Dim(batch),1:Dim(seq_length)}", self.string_type(ds))
118+
self.assertEqualAny(
119+
{
120+
"decoder_input_ids": {
121+
0: Dim("batch", min=1, max=1024),
122+
1: Dim("seq_length", min=1, max=4096),
123+
},
124+
"cache_position": {0: Dim("seq_length", min=1, max=4096)},
125+
"encoder_outputs": [{0: Dim("batch", min=1, max=1024)}],
126+
"past_key_values": [
127+
[
128+
[
129+
{0: Dim("batch", min=1, max=1024)},
130+
{0: Dim("batch", min=1, max=1024)},
131+
],
132+
[
133+
{0: Dim("batch", min=1, max=1024)},
134+
{0: Dim("batch", min=1, max=1024)},
135+
],
136+
],
137+
[
138+
[
139+
{0: Dim("batch", min=1, max=1024)},
140+
{0: Dim("batch", min=1, max=1024)},
141+
],
142+
[
143+
{0: Dim("batch", min=1, max=1024)},
144+
{0: Dim("batch", min=1, max=1024)},
145+
],
146+
],
147+
],
148+
},
149+
ds,
150+
)
115151
model(**inputs)
116152
self.assertEqual(
117153
"#1[T1r3]",
@@ -125,7 +161,16 @@ def test_get_untrained_model_with_inputs_automatic_speech_recognition(self):
125161
"#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]",
126162
self.string_type(flat),
127163
)
128-
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
164+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
165+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
166+
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
167+
self.assertIsInstance(flat, list)
168+
self.assertIsInstance(flat[0], torch.Tensor)
169+
self.assertEqual(
170+
"#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]",
171+
self.string_type(flat),
172+
)
173+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
129174

130175
@hide_stdout()
131176
def test_get_untrained_model_with_inputs_imagetext2text_generation(self):

onnx_diagnostic/ext_test_case.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,9 @@ def assertEqualAny(
910910
elif hasattr(expected, "shape"):
911911
self.assertEqual(type(expected), type(value), msg=msg)
912912
self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
913+
elif expected.__class__.__name__ in ("Dim", "_Dim"):
914+
self.assertEqual(type(expected), type(value), msg=msg)
915+
self.assertEqual(expected.__name__, value.__name__, msg=msg)
913916
else:
914917
raise AssertionError(
915918
f"Comparison not implemented for types {type(expected)} and {type(value)}"

onnx_diagnostic/helpers/helper.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ def string_type(
249249
limit=limit,
250250
)
251251
s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items())
252+
if all(isinstance(k, int) for k in obj):
253+
return f"{{{s}}}"
252254
return f"dict({s})"
253255
# array
254256
if isinstance(obj, np.ndarray):
@@ -279,7 +281,7 @@ def string_type(
279281
if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
280282
return "DerivedDim"
281283
if isinstance(obj, torch.export.dynamic_shapes._Dim):
282-
return "Dim"
284+
return f"Dim({obj.__name__})"
283285
if isinstance(obj, torch.SymInt):
284286
return "SymInt"
285287
if isinstance(obj, torch.SymFloat):
@@ -355,6 +357,11 @@ def string_type(
355357
if isinstance(obj, slice):
356358
return "slice"
357359

360+
if obj == torch.export.Dim.DYNAMIC:
361+
return "DYNAMIC"
362+
if obj == torch.export.Dim.AUTO:
363+
return "AUTO"
364+
358365
# others classes
359366

360367
if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
3636
original = cls._PATCHED_CLASS_
3737
methods = cls._PATCHES_
3838
if verbose:
39-
print(f"[patch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")
39+
print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
4040

4141
keep = {n: getattr(original, n, None) for n in methods}
4242
for n in methods:
@@ -69,7 +69,7 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
6969
for cls, methods in info.items():
7070
assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
7171
if verbose:
72-
print(f"[unpatch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")
72+
print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
7373
original = cls._PATCHED_CLASS_
7474
for n, v in methods.items():
7575
if v is None:

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 85 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,58 @@
11
import pprint
2-
from typing import Any, Dict, List, Set, Tuple
2+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
33
import packaging.version as pv
44
import optree
55
import torch
66
import transformers
77
from transformers.cache_utils import DynamicCache, MambaCache, EncoderDecoderCache
88
from transformers.modeling_outputs import BaseModelOutput
9+
from ..helpers import string_type
910

1011

1112
PATCH_OF_PATCHES: Set[Any] = set()
1213

1314

15+
def _register_class_serialization(
16+
cls,
17+
f_flatten: Callable,
18+
f_unflatten: Callable,
19+
f_flatten_with_keys: Callable,
20+
f_check: Optional[Callable] = None,
21+
verbose: int = 0,
22+
) -> bool:
23+
if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
24+
return False
25+
26+
if verbose:
27+
print(f"[_register_cache_serialization] register {cls}")
28+
torch.utils._pytree.register_pytree_node(
29+
cls,
30+
f_flatten,
31+
f_unflatten,
32+
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
33+
flatten_with_keys_fn=f_flatten_with_keys,
34+
)
35+
if pv.Version(torch.__version__) < pv.Version("2.7"):
36+
if verbose:
37+
print(
38+
f"[_register_cache_serialization] "
39+
f"register {cls} for torch=={torch.__version__}"
40+
)
41+
torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])
42+
43+
# check
44+
if f_check:
45+
inst = f_check()
46+
values, spec = torch.utils._pytree.tree_flatten(inst)
47+
restored = torch.utils._pytree.tree_unflatten(values, spec)
48+
assert string_type(inst, with_shape=True) == string_type(restored, with_shape=True), (
49+
f"Issue with registration of class {cls} "
50+
f"inst={string_type(inst, with_shape=True)}, "
51+
f"restored={string_type(restored, with_shape=True)}"
52+
)
53+
return True
54+
55+
1456
def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
1557
# DynamicCache serialization is different in transformers and does not
1658
# play way with torch.export.export.
@@ -28,26 +70,20 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
2870
):
2971
if verbose:
3072
print(
31-
"[_register_cache_serialization] DynamicCache "
32-
"is unregistered and registered first."
73+
f"[_fix_registration] DynamicCache is unregistered and "
74+
f"registered first for transformers=={transformers.__version__}"
3375
)
3476
_unregister(DynamicCache, verbose=verbose)
35-
torch.utils._pytree.register_pytree_node(
77+
_register_class_serialization(
3678
DynamicCache,
3779
flatten_dynamic_cache,
3880
unflatten_dynamic_cache,
39-
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
40-
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
81+
flatten_with_keys_dynamic_cache,
82+
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
83+
verbose=verbose,
4184
)
4285
if verbose:
43-
print(
44-
"[_register_cache_serialization] DynamicCache "
45-
"unregistered and registered done."
46-
)
47-
if pv.Version(torch.__version__) < pv.Version("2.7"):
48-
torch.fx._pytree.register_pytree_flatten_spec(
49-
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
50-
)
86+
print("[_fix_registration] DynamicCache done.")
5187
# To avoid doing it multiple times.
5288
PATCH_OF_PATCHES.add(DynamicCache)
5389

@@ -59,120 +95,52 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
5995
):
6096
if verbose:
6197
print(
62-
"[_register_cache_serialization] BaseModelOutput "
63-
"is unregistered and registered first."
98+
f"[_fix_registration] BaseModelOutput is unregistered and "
99+
f"registered first for transformers=={transformers.__version__}"
64100
)
65101
_unregister(BaseModelOutput, verbose=verbose)
66-
torch.utils._pytree.register_pytree_node(
102+
_register_class_serialization(
67103
BaseModelOutput,
68104
flatten_base_model_output,
69105
unflatten_base_model_output,
70-
serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
71-
flatten_with_keys_fn=flatten_with_keys_base_model_output,
106+
flatten_with_keys_base_model_output,
107+
verbose=verbose,
72108
)
73109
if verbose:
74-
print(
75-
"[_register_cache_serialization] BaseModelOutput "
76-
"unregistered and registered done."
77-
)
110+
print("[_fix_registration] BaseModelOutput done.")
78111

79112
# To avoid doing it multiple times.
80113
PATCH_OF_PATCHES.add(BaseModelOutput)
81114

82-
unregistered_dynamic_cache = True
83-
if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
84-
if verbose > 1:
85-
print(f"[_register_cache_serialization] {DynamicCache} already registered")
86-
unregistered_dynamic_cache = False
87-
else:
88-
if verbose:
89-
print("[_register_cache_serialization] register DynamicCache")
90-
torch.utils._pytree.register_pytree_node(
91-
DynamicCache,
92-
flatten_dynamic_cache,
93-
unflatten_dynamic_cache,
94-
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
95-
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
96-
)
97-
if pv.Version(torch.__version__) < pv.Version("2.7"):
98-
torch.fx._pytree.register_pytree_flatten_spec(
99-
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
100-
)
101-
102-
# check
103-
from ..helpers.cache_helper import make_dynamic_cache
104-
105-
cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
106-
values, spec = torch.utils._pytree.tree_flatten(cache)
107-
cache2 = torch.utils._pytree.tree_unflatten(values, spec)
108-
# torch.fx._pytree.tree_flatten(cache)
109-
assert len(cache2.key_cache) == 1
110-
111-
# BaseModelOutput
112-
unregistered_base_model_output = True
113-
if BaseModelOutput is not None and BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES:
114-
if verbose > 1:
115-
print(f"[_register_cache_serialization] {BaseModelOutput} already registered")
116-
# It is already registered because bypass_export_some_errors was called
117-
# within a section already calling bypass_export_some_errors or transformers
118-
# has updated its code to do it.
119-
# No need to register and unregister then.
120-
unregistered_base_model_output = False
121-
else:
122-
if verbose:
123-
print("[_register_cache_serialization] register BaseModelOutput")
124-
torch.utils._pytree.register_pytree_node(
125-
BaseModelOutput,
126-
flatten_encoder_decoder_cache,
127-
unflatten_encoder_decoder_cache,
128-
serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
129-
flatten_with_keys_fn=flatten_with_keys_base_model_output,
130-
)
131-
132-
# MambaCache
133-
unregistered_mamba_cache = True
134-
if MambaCache in torch.utils._pytree.SUPPORTED_NODES:
135-
if verbose > 1:
136-
print(f"[_register_cache_serialization] {MambaCache} already registered")
137-
# It is already registered because bypass_export_some_errors was called
138-
# within a section already calling bypass_export_some_errors or transformers
139-
# has updated its code to do it.
140-
# No need to register and unregister then.
141-
unregistered_mamba_cache = False
142-
else:
143-
if verbose:
144-
print("[_register_cache_serialization] register MambaCache")
145-
torch.utils._pytree.register_pytree_node(
146-
MambaCache,
147-
flatten_mamba_cache,
148-
unflatten_mamba_cache,
149-
serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}",
150-
flatten_with_keys_fn=flatten_with_keys_mamba_cache,
151-
)
152-
153-
# EncoderDecoderCache
154-
unregistered_encode_decode_cache = True
155-
if (
156-
EncoderDecoderCache is not None
157-
and EncoderDecoderCache in torch.utils._pytree.SUPPORTED_NODES
158-
):
159-
if verbose > 1:
160-
print(f"[_register_cache_serialization] {EncoderDecoderCache} already registered")
161-
# It is already registered because bypass_export_some_errors was called
162-
# within a section already calling bypass_export_some_errors or transformers
163-
# has updated its code to do it.
164-
# No need to register and unregister then.
165-
unregistered_encode_decode_cache = False
166-
else:
167-
if verbose:
168-
print("[_register_cache_serialization] register EncoderDecoderCache")
169-
torch.utils._pytree.register_pytree_node(
170-
EncoderDecoderCache,
171-
flatten_encoder_decoder_cache,
172-
unflatten_encoder_decoder_cache,
173-
serialized_type_name=f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}",
174-
flatten_with_keys_fn=flatten_with_keys_encoder_decoder_cache,
175-
)
115+
unregistered_dynamic_cache = _register_class_serialization(
116+
DynamicCache,
117+
flatten_dynamic_cache,
118+
unflatten_dynamic_cache,
119+
flatten_with_keys_dynamic_cache,
120+
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
121+
verbose=verbose,
122+
)
123+
unregistered_base_model_output = _register_class_serialization(
124+
BaseModelOutput,
125+
flatten_base_model_output,
126+
unflatten_base_model_output,
127+
flatten_with_keys_base_model_output,
128+
verbose=verbose,
129+
)
130+
unregistered_encode_decode_cache = _register_class_serialization(
131+
EncoderDecoderCache,
132+
flatten_encoder_decoder_cache,
133+
unflatten_encoder_decoder_cache,
134+
flatten_with_keys_encoder_decoder_cache,
135+
verbose=verbose,
136+
)
137+
unregistered_mamba_cache = _register_class_serialization(
138+
MambaCache,
139+
flatten_mamba_cache,
140+
unflatten_mamba_cache,
141+
flatten_with_keys_mamba_cache,
142+
verbose=verbose,
143+
)
176144

177145
return dict(
178146
DynamicCache=unregistered_dynamic_cache,
@@ -213,8 +181,6 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
213181
for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]:
214182
if undo.get(cls.__name__, False):
215183
_unregister(cls, verbose)
216-
elif verbose > 1:
217-
print(f"[_unregister_cache_serialization] skip unregister {cls.__name__}")
218184

219185

220186
############

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,7 @@ def empty(value: Any) -> bool:
2323

2424

2525
def _ds_clean(v):
26-
return (
27-
str(v)
28-
.replace(",min=None", "")
29-
.replace(",max=None", "")
30-
.replace(",_factory=True", "")
31-
.replace("<class 'onnx_diagnostic.torch_models.hghub.model_inputs.", "")
32-
.replace("'>", "")
33-
.replace("_DimHint(type=<_DimHintType.DYNAMIC: 3>)", "DYNAMIC")
34-
.replace("_DimHint(type=<_DimHintType.AUTO: 3>)", "AUTO")
35-
)
26+
return string_type(v)
3627

3728

3829
def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, Any]:

0 commit comments

Comments
 (0)