Skip to content

Commit fa664f1

Browse files
authored
extract subfolder from modelid//subfolder (#256)
* preprpocesid//subfolder * changes
1 parent a3f9f83 commit fa664f1

File tree

3 files changed

+26
-16
lines changed

3 files changed

+26
-16
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ Change Logs
44
0.7.14
55
++++++
66

7-
* :pr:`252`: adds new set of inputs for task texgt-generation
7+
* :pr:`256`: extract subfolder from modelid//subfolder
8+
* :pr:`252`: adds new sets of inputs for task texgt-generation
89
* :pr:`250`: add variables to track sequence nodes
910
* :pr:`249`: patches _maybe_broadcast to support a corner case
1011

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ def _code_needing_rewriting(model: Any) -> Any:
2525
return code_needing_rewriting(model)
2626

2727

28+
def _preprocess_model_id(
29+
model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
30+
) -> Tuple[str, Optional[str], bool, bool]:
31+
if subfolder or "//" not in model_id:
32+
return model_id, subfolder, same_as_pretrained, use_pretrained
33+
spl = model_id.split("//")
34+
if spl[-1] == "pretrained":
35+
return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
36+
if spl[-1] in {"transformer", "vae"}:
37+
# known subfolder
38+
return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
39+
return model_id, subfolder, same_as_pretrained, use_pretrained
40+
41+
2842
def get_untrained_model_with_inputs(
2943
model_id: str,
3044
config: Optional[Any] = None,
@@ -85,8 +99,16 @@ def get_untrained_model_with_inputs(
8599
f"model_id={model_id!r}, preinstalled model is only available "
86100
f"if use_only_preinstalled is False."
87101
)
102+
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
103+
model_id,
104+
subfolder,
105+
same_as_pretrained=same_as_pretrained,
106+
use_pretrained=use_pretrained,
107+
)
88108
if verbose:
89-
print(f"[get_untrained_model_with_inputs] model_id={model_id!r}")
109+
print(
110+
f"[get_untrained_model_with_inputs] model_id={model_id!r}, subfolder={subfolder!r}"
111+
)
90112
if use_preinstalled:
91113
print(f"[get_untrained_model_with_inputs] use preinstalled {model_id!r}")
92114
if config is None:

onnx_diagnostic/torch_models/validate.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ..torch_export_patches import torch_export_patches
2020
from ..torch_export_patches.patch_inputs import use_dyn_not_str
2121
from .hghub import get_untrained_model_with_inputs
22+
from .hghub.model_inputs import _preprocess_model_id
2223

2324

2425
def empty(value: Any) -> bool:
@@ -289,20 +290,6 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
289290
return new_cfg
290291

291292

292-
def _preprocess_model_id(
293-
model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
294-
) -> Tuple[str, Optional[str], bool, bool]:
295-
if subfolder or "//" not in model_id:
296-
return model_id, subfolder, same_as_pretrained, use_pretrained
297-
spl = model_id.split("//")
298-
if spl[-1] == "pretrained":
299-
return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
300-
if spl[-1] in {"transformer", "vae"}:
301-
# known subfolder
302-
return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
303-
return model_id, subfolder, same_as_pretrained, use_pretrained
304-
305-
306293
def validate_model(
307294
model_id: str,
308295
task: Optional[str] = None,

0 commit comments

Comments
 (0)