Skip to content

Commit c284e7c

Browse files
committed
strict=False by default
1 parent 2e874f1 commit c284e7c

File tree

4 files changed

+88
-79
lines changed

4 files changed

+88
-79
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
import transformers.cache_utils
66

77

8-
def flatten_unflatten_for_dynamic_shapes(obj: Any) -> Any:
8+
def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
99
"""
1010
Returns the object in a different structure similar to what
1111
the definition of the dynamic shapes should use.
1212
1313
:param obj: object from a custom class
14+
:param use_dict: closer to the original result but
15+
:func:`torch.export.export` only considers the values,
16+
the context gives the dictionary keys but it is not expressed
17+
in the dynamic shapes, these specifications seem to be different
18+
for the strict and non-strict modes.
1419
:return: the serialized object
1520
"""
1621
flat, spec = torch.utils._pytree.tree_flatten(obj)
@@ -20,11 +25,11 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any) -> Any:
2025
for subspec in spec.children_specs:
2126
end += subspec.num_leaves
2227
value = subspec.unflatten(flat[start:end])
23-
if not isinstance(value, (torch.Tensor, list)):
24-
value = flatten_unflatten_for_dynamic_shapes(value)
28+
if subspec.type is dict:
29+
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
2530
subtrees.append(value)
2631
start = end
27-
if spec.context:
32+
if subspec.type is dict:
2833
# This is a dictionary.
2934
return dict(zip(spec.context, subtrees))
3035
# This is a list.

onnx_diagnostic/helpers/helper.py

Lines changed: 66 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,72 @@ def string_type(
513513

514514
# others classes
515515

516+
if obj.__class__.__name__ == "MambaCache":
517+
c = string_type(
518+
obj.conv_states,
519+
with_shape=with_shape,
520+
with_min_max=with_min_max,
521+
with_device=with_device,
522+
limit=limit,
523+
verbose=verbose,
524+
)
525+
d = string_type(
526+
obj.ssm_states,
527+
with_shape=with_shape,
528+
with_min_max=with_min_max,
529+
with_device=with_device,
530+
limit=limit,
531+
verbose=verbose,
532+
)
533+
if verbose:
534+
print(f"[string_type] CACHE1:{type(obj)}")
535+
return f"MambaCache(conv_states={c}, ssm_states={d})"
536+
537+
if obj.__class__.__name__ == "DynamicCache":
538+
kc = string_type(
539+
obj.key_cache,
540+
with_shape=with_shape,
541+
with_min_max=with_min_max,
542+
with_device=with_device,
543+
limit=limit,
544+
verbose=verbose,
545+
)
546+
vc = string_type(
547+
obj.value_cache,
548+
with_shape=with_shape,
549+
with_min_max=with_min_max,
550+
with_device=with_device,
551+
limit=limit,
552+
verbose=verbose,
553+
)
554+
if verbose:
555+
print(f"[string_type] CACHE2:{type(obj)}")
556+
return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
557+
558+
if obj.__class__.__name__ == "EncoderDecoderCache":
559+
att = string_type(
560+
obj.self_attention_cache,
561+
with_shape=with_shape,
562+
with_min_max=with_min_max,
563+
with_device=with_device,
564+
limit=limit,
565+
verbose=verbose,
566+
)
567+
cross = string_type(
568+
obj.cross_attention_cache,
569+
with_shape=with_shape,
570+
with_min_max=with_min_max,
571+
with_device=with_device,
572+
limit=limit,
573+
verbose=verbose,
574+
)
575+
if verbose:
576+
print(f"[string_type] CACHE3:{type(obj)}")
577+
return (
578+
f"{obj.__class__.__name__}(self_attention_cache={att}, "
579+
f"cross_attention_cache={cross})"
580+
)
581+
516582
if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
517583
from .cache_helper import flatten_unflatten_for_dynamic_shapes
518584

@@ -595,74 +661,6 @@ def string_type(
595661
print(f"[string_type] TT8:{type(obj)}")
596662
return repr(obj).replace(" ", "").replace("\n", " ")
597663

598-
# to avoid failures
599-
600-
if obj.__class__.__name__ == "MambaCache":
601-
c = string_type(
602-
obj.conv_states,
603-
with_shape=with_shape,
604-
with_min_max=with_min_max,
605-
with_device=with_device,
606-
limit=limit,
607-
verbose=verbose,
608-
)
609-
d = string_type(
610-
obj.ssm_states,
611-
with_shape=with_shape,
612-
with_min_max=with_min_max,
613-
with_device=with_device,
614-
limit=limit,
615-
verbose=verbose,
616-
)
617-
if verbose:
618-
print(f"[string_type] CACHE1:{type(obj)}")
619-
return f"MambaCache(conv_states={c}, ssm_states={d})"
620-
621-
if obj.__class__.__name__ == "DynamicCache":
622-
kc = string_type(
623-
obj.key_cache,
624-
with_shape=with_shape,
625-
with_min_max=with_min_max,
626-
with_device=with_device,
627-
limit=limit,
628-
verbose=verbose,
629-
)
630-
vc = string_type(
631-
obj.value_cache,
632-
with_shape=with_shape,
633-
with_min_max=with_min_max,
634-
with_device=with_device,
635-
limit=limit,
636-
verbose=verbose,
637-
)
638-
if verbose:
639-
print(f"[string_type] CACHE2:{type(obj)}")
640-
return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
641-
642-
if obj.__class__.__name__ == "EncoderDecoderCache":
643-
att = string_type(
644-
obj.self_attention_cache,
645-
with_shape=with_shape,
646-
with_min_max=with_min_max,
647-
with_device=with_device,
648-
limit=limit,
649-
verbose=verbose,
650-
)
651-
cross = string_type(
652-
obj.cross_attention_cache,
653-
with_shape=with_shape,
654-
with_min_max=with_min_max,
655-
with_device=with_device,
656-
limit=limit,
657-
verbose=verbose,
658-
)
659-
if verbose:
660-
print(f"[string_type] CACHE3:{type(obj)}")
661-
return (
662-
f"{obj.__class__.__name__}(self_attention_cache={att}, "
663-
f"cross_attention_cache={cross})"
664-
)
665-
666664
if ignore:
667665
if verbose:
668666
print(f"[string_type] CACHE4:{type(obj)}")

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,12 @@ def get_inputs(
6969
)
7070
"""
7171
batch = torch.export.Dim("batch", min=1, max=1024)
72-
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
72+
seq_length = "seq_length"
7373

7474
shapes = {
7575
"decoder_input_ids": {0: batch, 1: seq_length},
7676
"cache_position": {0: seq_length},
77-
"encoder_outputs": {"last_hidden_state": {0: batch}},
77+
"encoder_outputs": [{0: batch}], # last_hidden_state
7878
"past_key_values": [
7979
[
8080
[{0: batch} for _ in range(num_hidden_layers)],

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ..helpers.torch_test_helper import to_any, torch_deepcopy
1313
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
1414
from ..torch_export_patches import bypass_export_some_errors
15+
from ..torch_export_patches.patch_inputs import use_dyn_not_str
1516
from .hghub import get_untrained_model_with_inputs
1617
from .hghub.model_inputs import random_input_kwargs
1718

@@ -633,14 +634,15 @@ def call_torch_export_export(
633634
another one with whatever the function produces
634635
"""
635636
assert exporter in {
637+
"export",
636638
"export-strict",
637639
"export-nostrict",
638640
}, f"Unexpected value for exporter={exporter!r}"
639641
assert not optimization, f"No optimization is implemented for exporter={exporter!r}"
640642
assert "model" in data, f"model is missing from data: {sorted(data)}"
641643
assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
642644
summary: Dict[str, Union[str, int, float]] = {}
643-
strict = "nostrict" not in exporter
645+
strict = "-strict" in exporter
644646
args, kwargs = split_args_kwargs(data["inputs_export"])
645647
ds = data.get("dynamic_shapes", None)
646648

@@ -652,7 +654,9 @@ def call_torch_export_export(
652654
summary["export_dynamic_shapes"] = string_type(ds)
653655

654656
# There is an issue with DynamicShape [[],[]] becomes []
655-
dse = CoupleInputsDynamicShapes(args, kwargs, ds).replace_string_by()
657+
dse = use_dyn_not_str(ds)
658+
# dse = CoupleInputsDynamicShapes(args, kwargs, ds).replace_string_by()
659+
656660
summary["export_dynamic_shapes_export_export"] = string_type(dse)
657661

658662
if verbose:
@@ -1015,7 +1019,7 @@ def call_torch_export_custom(
10151019
assert "model" in data, f"model is missing from data: {sorted(data)}"
10161020
assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
10171021
summary: Dict[str, Union[str, int, float]] = {}
1018-
dynamo = "nostrict" not in exporter
1022+
strict = "-strict" in exporter
10191023
args, kwargs = split_args_kwargs(data["inputs_export"])
10201024
ds = data.get("dynamic_shapes", None)
10211025
if verbose:
@@ -1029,15 +1033,15 @@ def call_torch_export_custom(
10291033
print("[call_torch_export_custom] export...")
10301034
summary["export_exporter"] = exporter
10311035
summary["export_optimization"] = optimization or ""
1032-
summary["export_dynamo"] = dynamo
1036+
summary["export_strict"] = strict
10331037
summary["export_args"] = string_type(args, with_shape=True)
10341038
summary["export_kwargs"] = string_type(kwargs, with_shape=True)
10351039

10361040
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
10371041
from experimental_experiment.xbuilder import OptimizationOptions
10381042

10391043
export_options = ExportOptions(
1040-
strict="nostrict" not in exporter,
1044+
strict=strict,
10411045
decomposition_table=(
10421046
"dec" if "-dec" in exporter else ("all" if "-all" in exporter else None)
10431047
),
@@ -1057,6 +1061,7 @@ def call_torch_export_custom(
10571061
optimize=bool(optimization),
10581062
large_model=True,
10591063
return_optimize_report=True,
1064+
verbose=max(verbose - 2, 0),
10601065
)
10611066
except Exception as e:
10621067
summary["ERR_export_export"] = str(e)
@@ -1074,6 +1079,7 @@ def call_torch_export_custom(
10741079
optimize=bool(optimization),
10751080
large_model=True,
10761081
return_optimize_report=True,
1082+
verbose=max(verbose - 2, 0),
10771083
)
10781084

10791085
new_stat = {}

0 commit comments

Comments
 (0)