Skip to content

Commit 4eb5235

Browse files
committed
fix mispelling
1 parent b3acf09 commit 4eb5235

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
ExtTestCase,
55
ignore_warnings,
66
requires_transformers,
7-
requires_pytorch,
7+
requires_torch,
88
)
99
from onnx_diagnostic.torch_models.llms import get_phi2
1010
from onnx_diagnostic.helpers import string_type
@@ -19,7 +19,7 @@ def test_get_phi2(self):
1919

2020
@ignore_warnings(UserWarning)
2121
@requires_transformers("4.54")
22-
@requires_pytorch("2.9.99")
22+
@requires_torch("2.9.99")
2323
def test_export_phi2_1(self):
2424
# exporting vmap does not work
2525
data = get_phi2(num_hidden_layers=2)

onnx_diagnostic/torch_models/validate.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -349,22 +349,23 @@ def validate_model(
349349
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
350350
"""
351351
if isinstance(patch, bool):
352-
patch = (
352+
patch_kwargs = (
353353
dict(patch_transformers=True, patch_diffusers=True, patch=True)
354354
if patch
355355
else dict(patch=False)
356356
)
357357
elif isinstance(patch, str):
358-
patch = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
358+
patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
359359
else:
360360
assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
361-
patch = patch.copy()
362-
if "patch" not in patch:
363-
if any(patch.values):
364-
patch["patch"] = True
365-
366-
assert not rewrite or patch, (
367-
f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting, "
361+
patch_kwargs = patch.copy()
362+
if "patch" not in patch_kwargs:
363+
if any(patch_kwargs.values()):
364+
patch_kwargs["patch"] = True
365+
366+
assert not rewrite or patch_kwargs.get("patch", False), (
367+
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
368+
f"patch must be True to enable rewriting, "
368369
f"if --no-patch was specified on the command line, --no-rewrite must be added."
369370
)
370371
summary = version_summary()
@@ -379,6 +380,7 @@ def validate_model(
379380
version_optimization=optimization or "",
380381
version_quiet=str(quiet),
381382
version_patch=str(patch),
383+
version_patch_kwargs=str(patch_kwargs).replace(" ", ""),
382384
version_rewrite=str(rewrite),
383385
version_dump_folder=dump_folder or "",
384386
version_drop_inputs=str(list(drop_inputs or "")),
@@ -414,7 +416,7 @@ def validate_model(
414416
print(f"[validate_model] model_options={model_options!r}")
415417
print(f"[validate_model] get dummy inputs with input_options={input_options}...")
416418
print(
417-
f"[validate_model] rewrite={rewrite}, patch={patch}, "
419+
f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, "
418420
f"stop_if_static={stop_if_static}"
419421
)
420422
print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
@@ -590,7 +592,7 @@ def validate_model(
590592
f"[validate_model] -- export the model with {exporter!r}, "
591593
f"optimization={optimization!r}"
592594
)
593-
if patch:
595+
if patch_kwargs:
594596
if verbose:
595597
print(
596598
f"[validate_model] applies patches before exporting "
@@ -601,7 +603,7 @@ def validate_model(
601603
verbose=max(0, verbose - 1),
602604
rewrite=data.get("rewrite", None),
603605
dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
604-
**patch,
606+
**patch_kwargs,
605607
) as modificator:
606608
data["inputs_export"] = modificator(data["inputs"]) # type: ignore
607609

0 commit comments

Comments
 (0)