
Commit afebca5

Allow converting a model with other task inputs
1 parent: 2b5cb7c

3 files changed: 27 additions, 11 deletions

onnx_diagnostic/tasks/text_generation.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,9 @@
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
     check_hasattr(
         config,
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
@@ -308,6 +311,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
 
     If the configuration is None, the function selects typical dimensions.
     """
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
     if config is not None:
         check_hasattr(
             config,
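As an aside, here is a minimal sketch (not part of the commit) of what the text_config unwrapping does: a multimodal configuration exposing a text_config attribute is reduced to its text sub-configuration before the usual attribute checks run. The SimpleNamespace objects below are stand-ins for real transformers configs.

from types import SimpleNamespace
from typing import Any


def unwrap_text_config(config: Any) -> Any:
    """Mirrors the pattern added above: keep only the text part of a
    mixture-of-models configuration (e.g. vision + text)."""
    if hasattr(config, "text_config"):
        config = config.text_config
    return config


# Stand-ins for a multimodal config and a plain text config.
multimodal = SimpleNamespace(text_config=SimpleNamespace(hidden_size=64, num_attention_heads=4))
plain = SimpleNamespace(hidden_size=64, num_attention_heads=4)

assert unwrap_text_config(multimodal).hidden_size == 64
assert unwrap_text_config(plain).num_attention_heads == 4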

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 5 additions & 3 deletions
@@ -120,14 +120,16 @@ def get_untrained_model_with_inputs(
         **(model_kwargs or {}),
     )
 
-    model, task, mkwargs, diff_config = None, None, {}, None
+    model, task_, mkwargs, diff_config = None, None, {}, None
     if use_pretrained and same_as_pretrained:
         if model_id in HANDLED_MODELS:
-            model, task, config = load_specific_model(model_id, verbose=verbose)
+            model, task_, config = load_specific_model(model_id, verbose=verbose)
 
+    if task is None:
+        task = task_
     if model is None:
         arch = architecture_from_config(config)
-        if arch is None:
+        if task is None and arch is None:
             task = task_from_id(model_id, subfolder=subfolder)
         assert task is not None or arch is not None, (
             f"Unable to determine the architecture for model {model_id!r}, "

onnx_diagnostic/torch_models/validate.py

Lines changed: 16 additions & 8 deletions
@@ -117,11 +117,21 @@ def _make_folder_name(
     drop_inputs: Optional[List[str]] = None,
     same_as_pretrained: bool = False,
     use_pretrained: bool = False,
+    task: Optional[str] = None,
 ) -> str:
     "Creates a filename unique based on the given options."
     els = [model_id.replace("/", "_")]
     if subfolder:
         els.append(subfolder.replace("/", "_"))
+    if not task:
+        els.append(task)
+    if drop_inputs:
+        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
+        els.append(f"I-{ii.upper()}")
+    if use_pretrained:
+        els.append("TRAINED")
+    elif same_as_pretrained:
+        els.append("SAMESIZE")
     if exporter:
         els.append(exporter)
     if optimization:
@@ -142,14 +152,7 @@ def _make_folder_name(
         els.append(sdev)
     if opset is not None:
         els.append(f"op{opset}")
-    if drop_inputs:
-        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
-        els.append(f"I-{ii.upper()}")
-    if use_pretrained:
-        els.append("TRAINED")
-    elif same_as_pretrained:
-        els.append("SAMESIZE")
-    return "-".join(els)
+    return "/".join(els)
 
 
 def version_summary() -> Dict[str, Union[int, float, str]]:
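A small hypothetical example (values invented for illustration) of the effect of switching the final join to "/": the dump folder becomes a nested directory tree, and the call sites shown further down flatten it back with replace("/", "-") when a flat file name is needed for the exported-program and ONNX artifacts.

# Hypothetical parts of a folder name; the real values come from the options
# passed to validate_model (model id, task, exporter, device, opset, ...).
els = ["microsoft_phi-2", "text-generation", "custom", "cpu", "op18"]

folder_name = "/".join(els)                # nested dump folder
file_stem = folder_name.replace("/", "-")  # flat stem reused for artifact files

print(folder_name)  # microsoft_phi-2/text-generation/custom/cpu/op18
print(file_stem)    # microsoft_phi-2-text-generation-custom-cpu-op18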
@@ -476,6 +479,7 @@ def validate_model(
             drop_inputs=drop_inputs,
             use_pretrained=use_pretrained,
             same_as_pretrained=same_as_pretrained,
+            task=task,
         )
         dump_folder = os.path.join(dump_folder, folder_name)
         if not os.path.exists(dump_folder):
@@ -490,6 +494,8 @@ def validate_model(
             print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
         else:
             print(f"[validate_model] validate model id {model_id!r}")
+        if task:
+            print(f"[validate_model] with task {task!r}")
         print(f"[validate_model] patch={patch!r}")
         if model_options:
             print(f"[validate_model] model_options={model_options!r}")
@@ -765,6 +771,7 @@ def validate_model(
         ep = data["exported_program"]
         if verbose:
             print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
+        folder_name = folder_name.replace("/", "-")
         with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
             f.write(str(ep))
         torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
@@ -773,6 +780,7 @@ def validate_model(
         if verbose:
             print("[validate_model] done (dump ep)")
     if "onnx_program" in data:
+        folder_name = folder_name.replace("/", "-")
         epo = data["onnx_program"]
         if verbose:
             print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
