
Commit 57c7215

fix
1 parent 8161a2d commit 57c7215

File tree

2 files changed (+120, -59 lines)


_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 16 additions & 0 deletions
@@ -5,6 +5,7 @@
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
     bypass_export_some_errors,
 )
+from transformers.modeling_outputs import BaseModelOutput
 
 
 class TestPatchSerialization(ExtTestCase):
@@ -83,6 +84,21 @@ def forward(self, cache):
         with bypass_export_some_errors():
             torch.export.export(model, (cache,), dynamic_shapes=(ds,))
 
+    @ignore_warnings(UserWarning)
+    def test_base_model_output(self):
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        with bypass_export_some_errors():
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            self.assertEqual(
+                "#1[T1s4x4x4]",
+                self.string_type(flat, with_shape=True),
+            )
+            bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(bo, with_shape=True, with_min_max=True),
+                self.string_type(bo2, with_shape=True, with_min_max=True),
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
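What the new test exercises, shown as a standalone snippet (a minimal sketch assuming torch, transformers, and onnx_diagnostic are installed; it swaps the ExtTestCase string-type assertions for plain tensor checks):

import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    bypass_export_some_errors,
)

bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
with bypass_export_some_errors():
    # Once BaseModelOutput is registered as a pytree node, tree_flatten
    # yields its tensors as leaves plus a spec describing the container...
    flat, spec = torch.utils._pytree.tree_flatten(bo)
    assert len(flat) == 1 and flat[0].shape == (4, 4, 4)
    # ...and tree_unflatten rebuilds an equivalent BaseModelOutput.
    bo2 = torch.utils._pytree.tree_unflatten(flat, spec)
    assert torch.equal(bo.last_hidden_state, bo2.last_hidden_state)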

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 104 additions & 59 deletions
@@ -1,35 +1,20 @@
 import pprint
 from typing import Any, Dict, List, Set, Tuple
+import packaging.version as pv
 import optree
 import torch
 import transformers
-import packaging.version as pv
+from transformers.cache_utils import DynamicCache, MambaCache, EncoderDecoderCache
+from transformers.modeling_outputs import BaseModelOutput
 
 
 PATCH_OF_PATCHES: Set[Any] = set()
 
 
 def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
-    # Cache serialization: to be moved into appropriate packages
-
-    try:
-        from transformers.cache_utils import DynamicCache
-    except ImportError:
-        DynamicCache = None
-
-    try:
-        from transformers.cache_utils import MambaCache
-    except ImportError:
-        MambaCache = None
-
-    try:
-        from transformers.cache_utils import EncoderDecoderCache
-    except ImportError:
-        EncoderDecoderCache = None
-
     # MambaCache
     unregistered_mamba_cache = True
-    if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES:
+    if MambaCache in torch.utils._pytree.SUPPORTED_NODES:
         if verbose > 1:
             print(f"[_register_cache_serialization] {MambaCache} already registered")
         # It is already registered because bypass_export_some_errors was called
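The `X in torch.utils._pytree.SUPPORTED_NODES` test above is the idiom the whole function is built on; a quick sketch of it in isolation (assuming a recent torch, where SUPPORTED_NODES is the private registry mapping container types to their flatten/unflatten functions):

import torch
from transformers.cache_utils import MambaCache

# Membership in the registry means "already registered as a pytree node",
# so the function can skip re-registering and skip the later unregister.
print(MambaCache in torch.utils._pytree.SUPPORTED_NODES)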
@@ -82,6 +67,26 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
         # To avoid doing it multiple times.
         PATCH_OF_PATCHES.add(DynamicCache)
 
+    # BaseModelOutput serialization is incomplete.
+    # It does not include dynamic shapes mapping.
+    if BaseModelOutput in torch.fx._pytree.SUPPORTED_NODES and not PATCH_OF_PATCHES:
+        if verbose:
+            print(
+                "[_register_cache_serialization] BaseModelOutput "
+                "is unregistered and registered first."
+            )
+        _unregister(BaseModelOutput)
+        torch.utils._pytree.register_pytree_node(
+            BaseModelOutput,
+            flatten_base_model_output,
+            unflatten_base_model_output,
+            serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_base_model_output,
+        )
+
+        # To avoid doing it multiple times.
+        PATCH_OF_PATCHES.add(BaseModelOutput)
+
     unregistered_dynamic_cache = True
     if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
         if verbose > 1:
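For readers unfamiliar with the API, this is the registration pattern the block above applies to BaseModelOutput, reduced to a toy container (Point is a hypothetical class invented for illustration, not part of the commit):

from typing import Any, List, Tuple
import torch


class Point:
    def __init__(self, x: torch.Tensor, y: torch.Tensor):
        self.x, self.y = x, y


def flatten_point(p: Point) -> Tuple[List[Any], torch.utils._pytree.Context]:
    # leaves in a fixed order, plus a context used to rebuild the object
    return [p.x, p.y], ["x", "y"]


def unflatten_point(values: List[Any], context: torch.utils._pytree.Context) -> Point:
    return Point(**dict(zip(context, values)))


torch.utils._pytree.register_pytree_node(
    Point,
    flatten_point,
    unflatten_point,
    serialized_type_name=f"{Point.__module__}.{Point.__name__}",
)

flat, spec = torch.utils._pytree.tree_flatten(Point(torch.ones(2), torch.zeros(2)))
assert len(flat) == 2  # both tensors surfaced as leaves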
@@ -123,7 +128,7 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
         # within a section already calling bypass_export_some_errors or transformers
         # has updated its code to do it.
         # No need to register and unregister then.
-        unregistered_mamba_cache = False
+        unregistered_encode_decode_cache = False
     else:
         if verbose:
             print("[_register_cache_serialization] register EncoderDecoderCache")
@@ -135,10 +140,32 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
             flatten_with_keys_fn=flatten_with_keys_encoder_decoder_cache,
         )
 
+    # BaseModelOutput
+    unregistered_base_model_output = True
+    if BaseModelOutput is not None and BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose > 1:
+            print(f"[_register_cache_serialization] {BaseModelOutput} already registered")
+        # It is already registered because bypass_export_some_errors was called
+        # within a section already calling bypass_export_some_errors or transformers
+        # has updated its code to do it.
+        # No need to register and unregister then.
+        unregistered_base_model_output = False
+    else:
+        if verbose:
+            print("[_register_cache_serialization] register BaseModelOutput")
+        torch.utils._pytree.register_pytree_node(
+            BaseModelOutput,
+            flatten_base_model_output,
+            unflatten_base_model_output,
+            serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_base_model_output,
+        )
+
     return dict(
         DynamicCache=unregistered_dynamic_cache,
         MambaCache=unregistered_mamba_cache,
         EncoderDecoderCache=unregistered_encode_decode_cache,
+        BaseModelOutput=unregistered_base_model_output,
     )
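A hedged sketch of why this registration matters for export: a module whose forward returns a BaseModelOutput can only be handled by torch.export once the output type is flattenable (TinyEncoder is hypothetical; the context manager is the package's own entry point, which performs the registrations above):

import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    bypass_export_some_errors,
)


class TinyEncoder(torch.nn.Module):
    def forward(self, x):
        # torch.export must flatten this return value into plain tensors
        return BaseModelOutput(last_hidden_state=x * 2)


with bypass_export_some_errors():
    ep = torch.export.export(TinyEncoder(), (torch.rand(2, 3),))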
@@ -167,20 +194,11 @@ def _unregister(cls: type, verbose: int = 0):
 
 
 def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
-    if undo.get("MambaCache", False):
-        _unregister(transformers.cache_utils.MambaCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister MambaCache")
-
-    if undo.get("DynamicCache", False):
-        _unregister(transformers.cache_utils.DynamicCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister DynamicCache")
-
-    if undo.get("EncoderDecoderCache", False):
-        _unregister(transformers.cache_utils.EncoderDecoderCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister EncoderDecoderCache")
+    for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]:
+        if undo.get(cls.__name__, False):
+            _unregister(cls, verbose)
+        elif verbose > 1:
+            print(f"[_unregister_cache_serialization] skip unregister {cls.__name__}")
 
 
 ############
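The refactored loop pairs with _register_cache_serialization as a bracket: the returned dict records which classes that particular call registered, so only those get unregistered afterwards. A minimal usage sketch, assuming both functions are imported from this module:

undo = _register_cache_serialization(verbose=1)
try:
    ...  # export runs while the serialization patches are in place
finally:
    # Undoes only the registrations recorded in `undo`; classes that were
    # already registered before the call are left untouched.
    _unregister_cache_serialization(undo, verbose=1)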
@@ -205,7 +223,7 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
 #     dtype=dtype,
 # )
 def flatten_mamba_cache(
-    mamba_cache: transformers.cache_utils.MambaCache,
+    mamba_cache: MambaCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
     flat = [
@@ -224,10 +242,8 @@
 
 
 def unflatten_mamba_cache(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> transformers.cache_utils.MambaCache:
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
     """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
     conv_states, ssm_states = values
@@ -258,12 +274,12 @@ def __init__(self):
     return cache
 
 
-def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
     List[Tuple[torch.utils._pytree.KeyEntry, Any]],
     torch.utils._pytree.Context,
 ]:
     """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    values, context = flatten_mamba_cache(d)
+    values, context = flatten_mamba_cache(cache)
     return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
@@ -273,7 +289,7 @@
 
 
 def flatten_dynamic_cache(
-    dynamic_cache: transformers.cache_utils.DynamicCache,
+    dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
     if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
@@ -287,11 +303,8 @@
 
 
 def flatten_with_keys_dynamic_cache(
-    dynamic_cache: transformers.cache_utils.DynamicCache,
-) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
     if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
         return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
@@ -300,10 +313,8 @@
 
 
 def unflatten_dynamic_cache(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> transformers.cache_utils.DynamicCache:
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
     if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
         assert output_type is None, f"output_type={output_type} not supported"
@@ -322,7 +333,7 @@
 
 
 def flatten_encoder_decoder_cache(
-    ec_cache: transformers.cache_utils.DynamicCache,
+    ec_cache: EncoderDecoderCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """
     Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
@@ -335,9 +346,7 @@
     return torch.utils._pytree._dict_flatten(dictionary)
 
 
-def flatten_with_keys_encoder_decoder_cache(
-    ec_cache: transformers.cache_utils.DynamicCache,
-) -> Tuple[
+def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
     List[Tuple[torch.utils._pytree.KeyEntry, Any]],
     torch.utils._pytree.Context,
 ]:
@@ -353,10 +362,46 @@
 
 
 def unflatten_encoder_decoder_cache(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> transformers.cache_utils.EncoderDecoderCache:
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> EncoderDecoderCache:
     """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
     return transformers.cache_utils.EncoderDecoderCache(**dictionary)
+
+
+#################
+# BaseModelOutput
+#################
+
+
+def flatten_base_model_output(
+    bo: BaseModelOutput,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
+    with python objects.
+    """
+    return list(bo.values()), list(bo.keys())
+
+
+def flatten_with_keys_base_model_output(
+    bo: BaseModelOutput,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
+    with python objects.
+    """
+    values, context = flatten_base_model_output(bo)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_base_model_output(
+    values: List[Any],
+    context: torch.utils._pytree.Context,
+    output_type=None,
+) -> BaseModelOutput:
+    """
+    Restores a :class:`transformers.modeling_outputs.BaseModelOutput`
+    from python objects.
+    """
+    return BaseModelOutput(**dict(zip(context, values)))
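A direct round-trip through the new helpers, as a sketch (assumes transformers is installed; note that a ModelOutput only exposes the fields that were actually set, so the context here is a single key):

import torch
from transformers.modeling_outputs import BaseModelOutput

bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
values, context = flatten_base_model_output(bo)
assert context == ["last_hidden_state"]  # only set fields appear
bo2 = unflatten_base_model_output(values, context)
assert torch.equal(bo.last_hidden_state, bo2.last_hidden_state)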
