Skip to content

Commit a7b3d4a

Browse files
committed
fix none value
1 parent f703ad6 commit a7b3d4a

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def get_untrained_model_with_inputs(
189189
f"subfolder={subfolder!r}"
190190
)
191191
model = transformers.AutoModel.from_pretrained(
192-
model_id, subfolder=subfolder, trust_remote_code=True, **mkwargs
192+
model_id, subfolder=subfolder or "", trust_remote_code=True, **mkwargs
193193
)
194194
if verbose:
195195
print(

onnx_diagnostic/torch_models/validate.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264264
return new_cfg
265265

266266

267-
def _preprocess_model_id(model_id, subfolder):
267+
def _preprocess_model_id(
268+
model_id: str, subfolder: str, same_as_pretrained: bool, use_pretrained: bool
269+
) -> Tuple[str, str, bool, bool]:
268270
if subfolder or "//" not in model_id:
269-
return model_id, subfolder
271+
return model_id, subfolder, same_as_pretrained, use_pretrained
270272
spl = model_id.split("//")
273+
if spl[-1] == "pretrained":
274+
return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
271275
if spl[-1] in {"transformer", "vae"}:
272276
# known subfolder
273-
return "//".join(spl[:-1]), spl[-1]
274-
return model_id, subfolder
277+
return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
278+
return model_id, subfolder, same_as_pretrained, use_pretrained
275279

276280

277281
def validate_model(
@@ -384,7 +388,12 @@ def validate_model(
384388
if ``runtime == 'ref'``,
385389
``orteval10`` increases the verbosity.
386390
"""
387-
model_id, subfolder = _preprocess_model_id(model_id, subfolder)
391+
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
392+
model_id,
393+
subfolder,
394+
same_as_pretrained=same_as_pretrained,
395+
use_pretrained=use_pretrained,
396+
)
388397
if isinstance(patch, bool):
389398
patch_kwargs = (
390399
dict(patch_transformers=True, patch_diffusers=True, patch=True)

0 commit comments

Comments
 (0)