Commit 26ee1ce

add custom

1 parent 7b8eb85 commit 26ee1ce

8 files changed: +353 -116 lines changed

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 44 additions & 7 deletions
@@ -1,16 +1,17 @@
 import unittest
 import torch
+from transformers.modeling_outputs import BaseModelOutput
 from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
 from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache, make_dynamic_cache
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
     bypass_export_some_errors,
 )
-from transformers.modeling_outputs import BaseModelOutput
+from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy


 class TestPatchSerialization(ExtTestCase):
     @ignore_warnings(UserWarning)
-    def test_flatten_encoder_decoder_cache(self):
+    def test_encoder_decoder_cache_flatten(self):
         cache = make_encoder_decoder_cache(
             make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
@@ -28,7 +29,17 @@ def test_flatten_encoder_decoder_cache(self):
         )

     @ignore_warnings(UserWarning)
-    def test_export_encoder_decoder_cache(self):
+    def test_encoder_decoder_cache_deepcopy(self):
+        cache = make_encoder_decoder_cache(
+            make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
+        )
+        with bypass_export_some_errors():
+            cache2 = torch_deepcopy([cache])
+        self.assertEqualAny([cache], cache2)
+
+    @ignore_warnings(UserWarning)
+    def test_encoder_decoder_cache_export(self):
         class Model(torch.nn.Module):
             def forward(self, cache):
                 return cache.self_attention_cache.key_cache[0]
@@ -53,7 +64,7 @@ def forward(self, cache):
         torch.export.export(model, (cache,), dynamic_shapes=(ds,))

     @ignore_warnings(UserWarning)
-    def test_flatten_dynamic_cache(self):
+    def test_dynamic_cache_flatten(self):
         cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
         with bypass_export_some_errors():
             flat, _spec = torch.utils._pytree.tree_flatten(cache)
@@ -68,7 +79,7 @@ def test_flatten_dynamic_cache(self):
         )

     @ignore_warnings(UserWarning)
-    def test_export_dynamic_cache(self):
+    def test_dynamic_cache_export(self):
         class Model(torch.nn.Module):
             def forward(self, cache):
                 return cache.key_cache[0]
@@ -85,7 +96,33 @@ def forward(self, cache):
         torch.export.export(model, (cache,), dynamic_shapes=(ds,))

     @ignore_warnings(UserWarning)
-    def test_base_model_output(self):
+    def test_dynamic_cache_deepcopy(self):
+        cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
+        with bypass_export_some_errors():
+            cache2 = torch_deepcopy([cache])
+        self.assertEqualAny([cache], cache2)
+
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_deepcopy(self):
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        self.assertEqual(bo.__class__.__name__, "BaseModelOutput")
+        with bypass_export_some_errors():
+            bo2 = torch_deepcopy([bo])
+        self.assertIsInstance(bo2, list)
+        self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput")
+        self.assertEqualAny([bo], bo2)
+
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_string_type(self):
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        with bypass_export_some_errors():
+            self.assertEqual(
+                "BaseModelOutput(last_hidden_state:T1s4x4x4)",
+                self.string_type(bo, with_shape=True),
+            )
+
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_flatten(self):
         bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
         with bypass_export_some_errors():
             flat, _spec = torch.utils._pytree.tree_flatten(bo)
@@ -100,7 +137,7 @@ def test_base_model_output(self):
             )

     @ignore_warnings(UserWarning)
-    def test_export_base_model_output(self):
+    def test_base_model_output_export(self):
         class Model(torch.nn.Module):
             def forward(self, cache):
                 return cache.last_hidden_state[0]
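
Note: the round-trip these tests exercise, shown as a minimal standalone sketch. It assumes bypass_export_some_errors registers BaseModelOutput with torch's pytree while active, which is what this commit wires up; only the variable names are new here.

import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    bypass_export_some_errors,
)

bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
with bypass_export_some_errors():
    # once registered, the output flattens to its tensor leaves...
    flat, spec = torch.utils._pytree.tree_flatten(bo)
    # ...and unflattens back to the same class, not a plain dict
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
assert type(restored) is BaseModelOutput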

onnx_diagnostic/_command_lines_parser.py

Lines changed: 4 additions & 0 deletions
@@ -315,6 +315,10 @@ def _cmd_validate(argv: List[Any]):
         for k, v in data["dynamic_shapes"].items():
             print(f" + {k.ljust(max_length)}: {_ds_clean(v)}")
     else:
+        # Skip any combination of options known to be unsupported.
+        if "onnx" not in args.export and "custom" not in args.export and args.opt:
+            print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
+            return
         summary, _data = validate_model(
             model_id=args.mid,
             task=args.task,
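
Note: the early exit in isolation, as a hedged sketch. skip_unsupported and the Namespace values are hypothetical; only the attribute names (args.export, args.opt) come from the diff.

from argparse import Namespace

def skip_unsupported(args) -> bool:
    # mirrors the guard above: opt is set while the exporter is
    # neither "onnx" nor "custom"
    return bool(args.opt) and "onnx" not in args.export and "custom" not in args.export

print(skip_unsupported(Namespace(export="eager", opt="ir")))   # True -> skipped
print(skip_unsupported(Namespace(export="custom", opt="ir")))  # False -> validated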

onnx_diagnostic/ext_test_case.py

Lines changed: 22 additions & 1 deletion
@@ -887,7 +887,18 @@ def assertEqual(self, expected: Any, value: Any, msg: str = ""):
     def assertEqualAny(
         self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = ""
     ):
-        if isinstance(expected, (tuple, list, dict)):
+        if expected.__class__.__name__ == "BaseModelOutput":
+            self.assertEqual(type(expected), type(value), msg=msg)
+            self.assertEqual(len(expected), len(value), msg=msg)
+            self.assertEqual(list(expected), list(value), msg=msg)  # checks the order
+            self.assertEqualAny(
+                {k: v for k, v in expected.items()},  # noqa: C416
+                {k: v for k, v in value.items()},  # noqa: C416
+                atol=atol,
+                rtol=rtol,
+                msg=msg,
+            )
+        elif isinstance(expected, (tuple, list, dict)):
             self.assertIsInstance(value, type(expected), msg=msg)
             self.assertEqual(len(expected), len(value), msg=msg)
             if isinstance(expected, dict):
@@ -898,13 +909,23 @@ def assertEqualAny(
             for e, g in zip(expected, value):
                 self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
         elif expected.__class__.__name__ == "DynamicCache":
+            self.assertEqual(type(expected), type(value), msg=msg)
             atts = ["key_cache", "value_cache"]
             self.assertEqualAny(
                 {k: expected.__dict__.get(k, None) for k in atts},
                 {k: value.__dict__.get(k, None) for k in atts},
                 atol=atol,
                 rtol=rtol,
             )
+        elif expected.__class__.__name__ == "EncoderDecoderCache":
+            self.assertEqual(type(expected), type(value), msg=msg)
+            atts = ["self_attention_cache", "cross_attention_cache"]
+            self.assertEqualAny(
+                {k: expected.__dict__.get(k, None) for k in atts},
+                {k: value.__dict__.get(k, None) for k in atts},
+                atol=atol,
+                rtol=rtol,
+            )
         elif isinstance(expected, (int, float, str)):
             self.assertEqual(expected, value, msg=msg)
         elif hasattr(expected, "shape"):
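
Note on why list(expected) "checks the order": iterating a BaseModelOutput yields its keys in insertion order, so comparing the two lists compares key order before the dict comparison checks the values. A small illustration, assuming transformers is installed:

import torch
from transformers.modeling_outputs import BaseModelOutput

bo = BaseModelOutput(last_hidden_state=torch.rand((2, 2)))
print(list(bo))                          # ['last_hidden_state'] - the keys, in order
print([v.shape for _, v in bo.items()])  # [torch.Size([2, 2])]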

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 7 additions & 3 deletions
@@ -11,6 +11,7 @@ def make_feeds(
     inputs: Any,
     use_numpy: bool = False,
     copy: bool = False,
+    check_flatten: bool = True,
 ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
     """
     Serializes the inputs to produce feeds expected
@@ -21,17 +22,20 @@ def make_feeds(
     :param use_numpy: if True, converts torch tensors into numpy arrays
     :param copy: a copy is made, this should be the case if the inputs is ingested
         by ``OrtValue``
+    :param check_flatten: if True, checks that ``torch.utils._pytree.tree_flatten``
+        returns the same number of outputs
     :return: feeds dictionary
     """
     flat = flatten_object(inputs, drop_keys=True)
     assert (
-        not all(isinstance(obj, torch.Tensor) for obj in flat)
+        not check_flatten
+        or not all(isinstance(obj, torch.Tensor) for obj in flat)
         or not is_cache_dynamic_registered(fast=True)
         or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
     ), (
         f"Unexpected number of flattened objects, "
-        f"{string_type(flat, with_shape=True, limit=20)} != "
-        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True,limit=20)}"
+        f"{string_type(flat, with_shape=True)} != "
+        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
        flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
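
Note on the motivation for check_flatten, sketched: the two flattening paths can disagree on the number of leaves when a type is handled by flatten_object but not registered with torch's pytree (or the reverse), which is exactly what the old assertion rejected. A hedged illustration:

import torch
from transformers.modeling_outputs import BaseModelOutput

bo = BaseModelOutput(last_hidden_state=torch.rand((2, 2)))
leaves, _spec = torch.utils._pytree.tree_flatten(bo)
# one leaf if BaseModelOutput is not pytree-registered by the installed
# torch/transformers combination, otherwise one leaf per tensor;
# make_feeds(..., check_flatten=False) tolerates either outcome
print(len(leaves))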

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 4 additions & 1 deletion
@@ -351,7 +351,10 @@ def torch_deepcopy(value: Any) -> Any:
     if isinstance(value, set):
         return {torch_deepcopy(v) for v in value}
     if isinstance(value, dict):
-        return {k: torch_deepcopy(v) for k, v in value.items()}
+        if type(value) is dict:
+            return {k: torch_deepcopy(v) for k, v in value.items()}
+        # for dict subclasses such as BaseModelOutput
+        return value.__class__(**{k: torch_deepcopy(v) for k, v in value.items()})
     if isinstance(value, np.ndarray):
         return value.copy()
     if hasattr(value, "clone"):
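
Note on the exact-type test: a plain dict comprehension silently downcasts dict subclasses such as BaseModelOutput, while rebuilding through value.__class__ preserves the type. A small illustration (clone of the values omitted to keep the point visible):

import torch
from transformers.modeling_outputs import BaseModelOutput

bo = BaseModelOutput(last_hidden_state=torch.rand((2, 2)))
downcast = {k: v for k, v in bo.items()}                 # type information lost
rebuilt = bo.__class__(**{k: v for k, v in bo.items()})  # type preserved
print(type(downcast).__name__)  # dict
print(type(rebuilt).__name__)   # BaseModelOutput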

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 3 additions & 0 deletions
@@ -424,6 +424,9 @@ def replacement_before_exporting(args: Any) -> Any:
         return None
     if isinstance(args, (int, float)):
         return args
+    if type(args) not in {dict, tuple, list}:
+        # BaseModelOutput is a dict subclass: leave it untouched
+        return args
     if isinstance(args, dict):
         return {k: replacement_before_exporting(v) for k, v in args.items()}
     if isinstance(args, tuple):
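
Note on what the new guard prevents: BaseModelOutput passes isinstance(args, dict), so without the exact-type check the dict branch below would rebuild it as a plain dict. A quick check:

import torch
from transformers.modeling_outputs import BaseModelOutput

bo = BaseModelOutput(last_hidden_state=torch.rand((2, 2)))
print(isinstance(bo, dict))             # True: it is a dict subclass
print(type(bo) in {dict, tuple, list})  # False: the guard returns it unchanged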

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 2 additions & 27 deletions
@@ -188,36 +188,13 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
 ############


-# self.conv_states: torch.Tensor = torch.zeros(
-#     config.num_hidden_layers,
-#     self.max_batch_size,
-#     self.intermediate_size,
-#     self.conv_kernel_size,
-#     device=device,
-#     dtype=dtype,
-# )
-# self.ssm_states: torch.Tensor = torch.zeros(
-#     config.num_hidden_layers,
-#     self.max_batch_size,
-#     self.intermediate_size,
-#     self.ssm_state_size,
-#     device=device,
-#     dtype=dtype,
-# )
 def flatten_mamba_cache(
     mamba_cache: MambaCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
     flat = [
         (k, getattr(mamba_cache, k))
-        for k in [
-            # "max_batch_size",  # new in transformers==4.47
-            # "intermediate_size",
-            # "ssm_state_size",
-            # "conv_kernel_size",
-            "conv_states",
-            "ssm_states",
-        ]
+        for k in ["conv_states", "ssm_states"]
         if hasattr(mamba_cache, k)
     ]
     return [f[1] for f in flat], [f[0] for f in flat]
@@ -242,8 +219,6 @@ def __init__(self):
             self.conv_kernel = conv_states.shape[3]
             self.num_hidden_layers = conv_states.shape[0]

-    from transformers.cache_utils import MambaCache
-
     cache = MambaCache(
         _config(),
         max_batch_size=1,
@@ -348,7 +323,7 @@ def unflatten_encoder_decoder_cache(
 ) -> EncoderDecoderCache:
     """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
-    return transformers.cache_utils.EncoderDecoderCache(**dictionary)
+    return EncoderDecoderCache(**dictionary)


 #################
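
Note on the flatten/unflatten contract these functions follow, reduced to a toy stand-in: a flatten function returns (leaves, context) and its unflatten partner rebuilds the object from them. ToyCache and the two helpers below are hypothetical; the real classes are transformers' MambaCache and EncoderDecoderCache.

from typing import Any, List, Tuple

class ToyCache:
    # hypothetical stand-in with the same two fields flatten_mamba_cache keeps
    def __init__(self, conv_states, ssm_states):
        self.conv_states = conv_states
        self.ssm_states = ssm_states

def flatten_toy_cache(cache: ToyCache) -> Tuple[List[Any], List[str]]:
    # leaves in a fixed order, field names as the context
    flat = [(k, getattr(cache, k)) for k in ["conv_states", "ssm_states"]]
    return [f[1] for f in flat], [f[0] for f in flat]

def unflatten_toy_cache(values: List[Any], context: List[str]) -> ToyCache:
    # pair the context back with the leaves and rebuild the object
    return ToyCache(**dict(zip(context, values)))

cache = ToyCache(conv_states=[1.0], ssm_states=[2.0])
leaves, ctx = flatten_toy_cache(cache)
assert unflatten_toy_cache(leaves, ctx).conv_states == [1.0]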
