Skip to content

Commit e49a75b

Browse files
committed
improve
1 parent feac0fa commit e49a75b

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

_unittests/ut_torch_models/test_validate_whole_models1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def test_l_validate_model_modelbuilder(self):
203203
verbose=10,
204204
exporter="modelbuilder",
205205
dump_folder="dump_test/validate_model_modelbuilder",
206-
patch=True,
206+
patch=False,
207207
)
208208
self.assertIsInstance(summary, dict)
209209
self.assertIsInstance(data, dict)

onnx_diagnostic/torch_models/validate.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
from ..helpers.torch_helper import to_any, torch_deepcopy
1717
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
1818
from ..tasks import random_input_kwargs
19-
from ..torch_export_patches import torch_export_patches
19+
from ..torch_export_patches import (
20+
torch_export_patches,
21+
register_additional_serialization_functions,
22+
)
2023
from ..torch_export_patches.patch_inputs import use_dyn_not_str
2124
from .hghub import get_untrained_model_with_inputs
2225
from .hghub.model_inputs import _preprocess_model_id
@@ -574,12 +577,7 @@ def validate_model(
574577
cpl = CoupleInputsDynamicShapes(
575578
tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
576579
)
577-
if patch_kwargs.get("patch", False):
578-
with torch_export_patches(**patch_kwargs): # type: ignore[arg-type]
579-
data[k] = cpl.change_dynamic_dimensions(
580-
desired_values=dict(batch=1), only_desired=True
581-
)
582-
else:
580+
with register_additional_serialization_functions(patch_transformers=True): # type: ignore[arg-type]
583581
data[k] = cpl.change_dynamic_dimensions(
584582
desired_values=dict(batch=1), only_desired=True
585583
)

0 commit comments

Comments
 (0)