Skip to content

Commit f6d0d18

Browse files
committed
improve ci
1 parent 853c150 commit f6d0d18

File tree

4 files changed

+41
-17
lines changed

4 files changed

+41
-17
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,6 @@ jobs:
110110
pip install torch==${{ matrix.torch }} torchvision torchaudio
111111
fi
112112
113-
- name: Uninstall triton
114-
run: python -m pip uninstall -y triton
115-
116113
- name: Cache pip
117114
uses: actions/cache@v4
118115
with:

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,12 @@ def get_parser_validate() -> ArgumentParser:
553553
action=BooleanOptionalAction,
554554
help="Enables onnxruntime logging when the session is created",
555555
)
556+
parser.add_argument(
557+
"--quiet-input-sets",
558+
default="",
559+
help="Avoids raising an exception when an input set does not work with "
560+
"the exported model, example: --quiet-input-sets=inputs,inputs22",
561+
)
556562
return parser
557563

558564

@@ -614,6 +620,7 @@ def _cmd_validate(argv: List[Any]):
614620
warmup=args.warmup,
615621
inputs2=args.inputs2,
616622
ort_logs=args.ort_logs,
623+
quiet_input_sets=set(args.quiet_input_sets.split(",")),
617624
output_names=(
618625
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
619626
),

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,26 @@
1212
StaticCache,
1313
)
1414

15-
try:
16-
from transformers.models.mamba.modeling_mamba import MambaCache
17-
except ImportError:
18-
from transformers.cache_utils import MambaCache
19-
2015
from ..helpers import string_type
2116
from .serialization import _lower_name_with_
2217

2318
PATCH_OF_PATCHES: Set[Any] = set()
2419

2520

21+
def get_mamba_cache_cls() -> type:
22+
try:
23+
from transformers.models.mamba.modeling_mamba import MambaCache
24+
25+
return MambaCache
26+
except ImportError:
27+
try:
28+
from transformers.cache_utils import MambaCache
29+
30+
return MambaCache
31+
except ImportError:
32+
return None
33+
34+
2635
def register_class_serialization(
2736
cls,
2837
f_flatten: Callable,
@@ -203,13 +212,6 @@ def serialization_functions(
203212
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
204213
verbose=verbose,
205214
),
206-
MambaCache: lambda verbose=verbose: register_class_serialization(
207-
MambaCache,
208-
flatten_mamba_cache,
209-
unflatten_mamba_cache,
210-
flatten_with_keys_mamba_cache,
211-
verbose=verbose,
212-
),
213215
EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
214216
EncoderDecoderCache,
215217
flatten_encoder_decoder_cache,
@@ -232,6 +234,17 @@ def serialization_functions(
232234
verbose=verbose,
233235
),
234236
}
237+
MambaCache = get_mamba_cache_cls()
238+
if MambaCache:
239+
transformers_classes[MambaCache] = (
240+
lambda verbose=verbose: register_class_serialization(
241+
MambaCache,
242+
flatten_mamba_cache,
243+
unflatten_mamba_cache,
244+
flatten_with_keys_mamba_cache,
245+
verbose=verbose,
246+
)
247+
)
235248
classes.update(transformers_classes)
236249

237250
if patch_diffusers:
@@ -287,7 +300,12 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
287300

288301
def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
289302
"""Undo all registrations."""
290-
cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
303+
MambaCache = get_mamba_cache_cls()
304+
cls_ensemble = (
305+
{DynamicCache, EncoderDecoderCache}
306+
| set(undo)
307+
| ({MambaCache} if MambaCache else set())
308+
)
291309
for cls in cls_ensemble:
292310
if undo.get(cls.__name__, False):
293311
unregister_class_serialization(cls, verbose)

onnx_diagnostic/torch_models/validate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,8 @@ def _mk(key, flavour=flavour):
14381438
keys = [("inputs", "run_expected", "")]
14391439
if second_input_keys:
14401440
keys.extend([(k, f"run_expected2{k[6:]}", f"2{k[6:]}") for k in second_input_keys])
1441+
if verbose:
1442+
print(f"[validate_onnx_model] -- keys={keys}")
14411443
for k_input, k_expected, suffix in keys:
14421444
# make_feeds
14431445
assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
@@ -1463,7 +1465,7 @@ def _mk(key, flavour=flavour):
14631465
# run ort
14641466
if verbose:
14651467
print(f"[validate_onnx_model] run session on inputs 'inputs{suffix}'...")
1466-
if quiet_input_sets:
1468+
if quiet_input_sets and f"inputs{suffix}" in quiet_input_sets:
14671469
print(f"[validate_onnx_model] quiet_input_sets={quiet_input_sets}")
14681470

14691471
got = _quiet_or_not_quiet(

0 commit comments

Comments
 (0)