Skip to content

Commit e252d79

Browse files
committed
supports one more scenario
1 parent 30ef056 commit e252d79

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def test_get_untrained_model_with_inputs_phi_2(self):
7474
)
7575

7676
@hide_stdout()
77+
@ignore_errors(OSError) # connectitivies issues
7778
def test_get_untrained_model_with_inputs_beit(self):
7879
mid = "hf-internal-testing/tiny-random-BeitForImageClassification"
7980
data = get_untrained_model_with_inputs(mid, verbose=1)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,11 @@ def get_parser_validate() -> ArgumentParser:
264264
action=BooleanOptionalAction,
265265
help="applies patches before exporting",
266266
)
267+
parser.add_argument(
268+
"--stop-if-static",
269+
default=0,
270+
help="raises an exception if a dynamic dimension becomes static",
271+
)
267272
parser.add_argument(
268273
"--trained",
269274
default=False,
@@ -319,6 +324,7 @@ def _cmd_validate(argv: List[Any]):
319324
dtype=args.dtype,
320325
device=args.device,
321326
patch=args.patch,
327+
stop_if_static=args.stop_if_static,
322328
optimization=args.opt,
323329
exporter=args.export,
324330
dump_folder=args.dump_folder,

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def validate_model(
202202
optimization: Optional[str] = None,
203203
quiet: bool = False,
204204
patch: bool = False,
205+
stop_if_static: bool = True,
205206
dump_folder: Optional[str] = None,
206207
drop_inputs: Optional[List[str]] = None,
207208
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
@@ -223,7 +224,10 @@ def validate_model(
223224
:param optimization: optimization to apply to the exported model,
224225
depend on the the exporter
225226
:param quiet: if quiet, catches exception if any issue
226-
:param patch: applies patches before exporting
227+
:param patch: applies patches (``patch_transformers=True``) before exporting,
228+
see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
229+
:param stop_if_static: stops if a dynamic dimension becomes static,
230+
see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
227231
:param dump_folder: dumps everything in a subfolder of this one
228232
:param drop_inputs: drops this list of inputs (given their names)
229233
:return: two dictionaries, one with some metrics,
@@ -354,7 +358,9 @@ def validate_model(
354358
if verbose:
355359
print("[validate_model] applies patches before exporting")
356360
with bypass_export_some_errors( # type: ignore
357-
patch_transformers=True, verbose=max(0, verbose - 1)
361+
patch_transformers=True,
362+
stop_if_static=stop_if_static,
363+
verbose=max(0, verbose - 1),
358364
) as modificator:
359365
data["inputs_export"] = modificator(data["inputs"]) # type: ignore
360366

0 commit comments

Comments
 (0)