Skip to content

Commit 576e12d

Browse files
authored
Fix use of pretrained version in validate (#158)
* really use pretrained version * add Phi3MoE * fix config * spell * fix order of inputs
1 parent 205b288 commit 576e12d

File tree

6 files changed

+156
-43
lines changed

6 files changed

+156
-43
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,24 @@ def get_parser_validate() -> ArgumentParser:
333333
of supported tasks.
334334
"""
335335
),
336-
epilog="If the model id is specified, one untrained version of it is instantiated.",
336+
epilog=textwrap.dedent(
337+
"""
338+
If the model id is specified, one untrained version of it is instantiated.
339+
Examples:
340+
341+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
342+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
343+
--dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
344+
345+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
346+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
347+
--dtype float16 --device cuda --patch --export custom --opt default
348+
349+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
350+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
351+
--dtype float16 --device cuda --export modelbuilder
352+
"""
353+
),
337354
formatter_class=RawTextHelpFormatter,
338355
)
339356
parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
@@ -372,6 +389,12 @@ def get_parser_validate() -> ArgumentParser:
372389
type=int,
373390
help="Raises an exception if a dynamic dimension becomes static.",
374391
)
392+
parser.add_argument(
393+
"--same-as-trained",
394+
default=False,
395+
action=BooleanOptionalAction,
396+
help="Validates a model identical to the trained model but not trained.",
397+
)
375398
parser.add_argument(
376399
"--trained",
377400
default=False,
@@ -487,7 +510,8 @@ def _cmd_validate(argv: List[Any]):
487510
do_run=args.run,
488511
verbose=args.verbose,
489512
quiet=args.quiet,
490-
trained=args.trained,
513+
same_as_pretrained=args.same_as_trained,
514+
use_pretrained=args.trained,
491515
dtype=args.dtype,
492516
device=args.device,
493517
patch=args.patch,
@@ -619,7 +643,13 @@ def get_parser_agg() -> ArgumentParser:
619643
and produces values. Every row has a date.
620644
"""
621645
),
622-
epilog="example\n python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1",
646+
epilog=textwrap.dedent(
647+
"""
648+
examples:\n
649+
650+
python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
651+
"""
652+
),
623653
formatter_class=RawTextHelpFormatter,
624654
)
625655
parser.add_argument("output", help="output excel file")

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,10 @@ def get_inputs(
9696
for i in range(num_hidden_layers)
9797
]
9898
),
99-
image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
99+
pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
100100
torch.int64
101101
),
102-
pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
102+
image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
103103
torch.int64
104104
),
105105
)
@@ -132,16 +132,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
132132
If the configuration is None, the function selects typical dimensions.
133133
"""
134134
if config is not None:
135-
check_hasattr(
136-
config,
137-
"vocab_size",
138-
"hidden_size",
139-
"num_attention_heads",
140-
("num_key_value_heads", "num_attention_heads"),
141-
"intermediate_size",
142-
"hidden_size",
143-
"vision_config",
144-
)
135+
if hasattr(config, "text_config"):
136+
check_hasattr(
137+
config.text_config,
138+
"vocab_size",
139+
"hidden_size",
140+
"num_attention_heads",
141+
("num_key_value_heads", "num_attention_heads"),
142+
"intermediate_size",
143+
"hidden_size",
144+
)
145+
check_hasattr(config, "vision_config")
146+
text_config = True
147+
else:
148+
check_hasattr(
149+
config,
150+
"vocab_size",
151+
"hidden_size",
152+
"num_attention_heads",
153+
("num_key_value_heads", "num_attention_heads"),
154+
"intermediate_size",
155+
"hidden_size",
156+
"vision_config",
157+
)
158+
text_config = False
145159
check_hasattr(config.vision_config, "image_size", "num_channels")
146160
kwargs = dict(
147161
batch_size=2,
@@ -150,17 +164,54 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
150164
head_dim=(
151165
16
152166
if config is None
153-
else getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
167+
else getattr(
168+
config,
169+
"head_dim",
170+
(config.text_config.hidden_size if text_config else config.hidden_size)
171+
// (
172+
config.text_config.num_attention_heads
173+
if text_config
174+
else config.num_attention_heads
175+
),
176+
)
177+
),
178+
dummy_max_token_id=(
179+
31999
180+
if config is None
181+
else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
182+
),
183+
num_hidden_layers=(
184+
4
185+
if config is None
186+
else (
187+
config.text_config.num_hidden_layers
188+
if text_config
189+
else config.num_hidden_layers
190+
)
154191
),
155-
dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
156-
num_hidden_layers=4 if config is None else config.num_hidden_layers,
157192
num_key_value_heads=(
158193
8
159194
if config is None
160-
else _pick(config, "num_key_value_heads", "num_attention_heads")
195+
else (
196+
_pick(config.text_config, "num_key_value_heads", "num_attention_heads")
197+
if text_config
198+
else _pick(config, "num_key_value_heads", "num_attention_heads")
199+
)
200+
),
201+
intermediate_size=(
202+
1024
203+
if config is None
204+
else (
205+
config.text_config.intermediate_size
206+
if text_config
207+
else config.intermediate_size
208+
)
209+
),
210+
hidden_size=(
211+
512
212+
if config is None
213+
else (config.text_config.hidden_size if text_config else config.hidden_size)
161214
),
162-
intermediate_size=1024 if config is None else config.intermediate_size,
163-
hidden_size=512 if config is None else config.hidden_size,
164215
width=224 if config is None else config.vision_config.image_size,
165216
height=224 if config is None else config.vision_config.image_size,
166217
num_channels=3 if config is None else config.vision_config.num_channels,

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,15 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
138138

139139

140140
@functools.cache
141-
def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
141+
def task_from_arch(
142+
arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
143+
) -> str:
142144
"""
143145
This function relies on stored information. That information needs to be refreshed.
144146
145147
:param arch: architecture name
146148
:param default_value: default value in case the task cannot be determined
149+
:param model_id: model id, only used as a fallback to determine the task when the architecture is not in the stored data
147150
:return: task
148151
149152
.. runpython::
@@ -156,9 +159,16 @@ def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
156159
<onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
157160
"""
158161
data = load_architecture_task()
162+
if arch not in data and model_id:
163+
# Let's try with the model id.
164+
return task_from_id(model_id)
159165
if default_value is not None:
160166
return data.get(arch, default_value)
161-
assert arch in data, f"Architecture {arch!r} is unknown, last refresh in {__date__}"
167+
assert arch in data, (
168+
f"Architecture {arch!r} is unknown, last refresh in {__date__}. "
169+
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__`` "
170+
f"needs to be updated (model_id={(model_id or '?')!r})."
171+
)
162172
return data[arch]
163173

164174

@@ -176,6 +186,7 @@ def task_from_id(
176186
if the task cannot be determined
177187
:param pretrained: uses the config
178188
:param fall_back_to_pretrained: falls back to pretrained config
189+
:param exc: raises an exception if True
179190
:return: task
180191
"""
181192
if not pretrained:
@@ -191,9 +202,14 @@ def task_from_id(
191202
guess = _guess_task_from_config(config)
192203
if guess is not None:
193204
return guess
205+
data = load_architecture_task()
206+
if model_id in data:
207+
return data[model_id]
194208
assert config.architectures is not None and len(config.architectures) == 1, (
195209
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
196-
f"architectures={config.architectures} in config={config}"
210+
f"architectures={config.architectures} in config={config}. "
211+
f"The task can be added in "
212+
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
197213
)
198214
return task_from_arch(config.architectures[0], default_value=default_value)
199215

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import textwrap
44
from typing import Dict, List
55

6-
__date__ = "2025-03-26"
6+
__date__ = "2025-06-21"
77

88
__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
99

@@ -52,6 +52,8 @@
5252
GPTNeoModel,feature-extraction
5353
GPTNeoXForCausalLM,text-generation
5454
GemmaForCausalLM,text-generation
55+
Gemma2ForCausalLM,text-generation
56+
Gemma3ForConditionalGeneration,image-text-to-text
5557
GraniteForCausalLM,text-generation
5658
GroupViTModel,feature-extraction
5759
HieraForImageClassification,image-classification
@@ -97,6 +99,7 @@
9799
PegasusModel,feature-extraction
98100
Phi3ForCausalLM,text-generation
99101
PhiForCausalLM,text-generation
102+
PhiMoEForCausalLM,text-generation
100103
Pix2StructForConditionalGeneration,image-to-text
101104
PLBartForConditionalGeneration,text2text-generation
102105
PoolFormerModel,image-feature-extraction
@@ -144,7 +147,8 @@
144147
XLMRobertaModel,sentence-similarity
145148
Wav2Vec2ForCTC,automatic-speech-recognition
146149
YolosForObjectDetection,object-detection
147-
YolosModel,image-feature-extraction"""
150+
YolosModel,image-feature-extraction
151+
emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
148152
)
149153

150154
__data_tasks__ = [

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import os
3+
import pprint
34
from typing import Any, Dict, Optional, Tuple
45
import torch
56
import transformers
@@ -22,6 +23,7 @@ def get_untrained_model_with_inputs(
2223
model_kwargs: Optional[Dict[str, Any]] = None,
2324
verbose: int = 0,
2425
dynamic_rope: Optional[bool] = None,
26+
use_pretrained: bool = False,
2527
same_as_pretrained: bool = False,
2628
use_preinstalled: bool = True,
2729
add_second_input: bool = False,
@@ -43,6 +45,7 @@ def get_untrained_model_with_inputs(
4345
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
4446
:param same_as_pretrained: if True, do not change the default values
4547
to get a smaller model
48+
:param use_pretrained: download the pretrained weights as well
4649
:param use_preinstalled: use preinstalled configurations
4750
:param add_second_input: provides a second set of inputs to check a model
4851
supports different shapes
@@ -68,6 +71,10 @@ def get_untrained_model_with_inputs(
6871
print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
6972
print("-- configuration:", pprint.pformat(data['configuration']))
7073
"""
74+
assert not use_preinstalled or not use_only_preinstalled, (
75+
f"model_id={model_id!r}, pretinstalled model is only available "
76+
f"if use_only_preinstalled is False."
77+
)
7178
if verbose:
7279
print(f"[get_untrained_model_with_inputs] model_id={model_id!r}")
7380
if use_preinstalled:
@@ -99,7 +106,7 @@ def get_untrained_model_with_inputs(
99106
print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
100107
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
101108
if task is None:
102-
task = task_from_arch(archs[0])
109+
task = task_from_arch(archs[0], model_id=model_id)
103110
if verbose:
104111
print(f"[get_untrained_model_with_inputs] task={task!r}")
105112

@@ -114,7 +121,6 @@ def get_untrained_model_with_inputs(
114121
)
115122

116123
# updating the configuration
117-
118124
mkwargs = reduce_model_config(config, task) if not same_as_pretrained else {}
119125
if model_kwargs:
120126
for k, v in model_kwargs.items():
@@ -139,27 +145,28 @@ def get_untrained_model_with_inputs(
139145
f"{config._attn_implementation!r}" # type: ignore[union-attr]
140146
)
141147

148+
if use_pretrained:
149+
model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
150+
else:
151+
if archs is not None:
152+
model = getattr(transformers, archs[0])(config)
153+
else:
154+
assert same_as_pretrained and use_pretrained, (
155+
f"Model {model_id!r} cannot be built, the model cannot be built. "
156+
f"It must be downloaded. Use same_as_pretrained=True "
157+
f"and use_pretrained=True."
158+
)
159+
142160
# input kwargs
143161
kwargs, fct = random_input_kwargs(config, task)
144162
if verbose:
145163
print(f"[get_untrained_model_with_inputs] use fct={fct}")
146164
if os.environ.get("PRINT_CONFIG") in (1, "1"):
147-
import pprint
148-
149165
print(f"-- input kwargs for task {task!r}")
150166
pprint.pprint(kwargs)
151167
if inputs_kwargs:
152168
kwargs.update(inputs_kwargs)
153169

154-
if archs is not None:
155-
model = getattr(transformers, archs[0])(config)
156-
else:
157-
assert same_as_pretrained, (
158-
f"Model {model_id!r} cannot be built, the model cannot be built. "
159-
f"It must be downloaded. Use same_as_pretrained=True."
160-
)
161-
model = None
162-
163170
# This line is important. Some models may produce different
164171
# outputs even with the same inputs in training mode.
165172
model.eval()

0 commit comments

Comments
 (0)