Skip to content

Commit 548bd93

Browse files
committed
support // for known subfolders
1 parent 2783fe9 commit 548bd93

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

onnx_diagnostic/torch_models/validate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,16 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264264
return new_cfg
265265

266266

267+
def _preprocess_model_id(model_id, subfolder):
268+
if subfolder or "//" not in model_id:
269+
return model_id, subfolder
270+
spl = model_id.split("//")
271+
if spl[-1] in {"transformer", "vae"}:
272+
# known subfolder
273+
return "//".join(spl[:-1]), spl[-1]
274+
return model_id, subfolder
275+
276+
267277
def validate_model(
268278
model_id: str,
269279
task: Optional[str] = None,
@@ -374,6 +384,7 @@ def validate_model(
374384
if ``runtime == 'ref'``,
375385
``orteval10`` increases the verbosity.
376386
"""
387+
model_id, subfolder = _preprocess_model_id(model_id, subfolder)
377388
if isinstance(patch, bool):
378389
patch_kwargs = (
379390
dict(patch_transformers=True, patch_diffusers=True, patch=True)

0 commit comments

Comments
 (0)