Skip to content

Commit 1480fdf

Browse files
committed
really use pretrained version
1 parent 432b14d commit 1480fdf

File tree

6 files changed

+157
-42
lines changed

6 files changed

+157
-42
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,24 @@ def get_parser_validate() -> ArgumentParser:
333333
of supported tasks.
334334
"""
335335
),
336-
epilog="If the model id is specified, one untrained version of it is instantiated.",
336+
epilog=textwrap.dedent(
337+
"""
338+
If the model id is specified, one untrained version of it is instantiated.
339+
Examples:
340+
341+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
342+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
343+
--dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
344+
345+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
346+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
347+
--dtype float16 --device cuda --patch --export custom --opt default
348+
349+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
350+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
351+
--dtype float16 --device cuda --export modelbuilder
352+
"""
353+
),
337354
formatter_class=RawTextHelpFormatter,
338355
)
339356
parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
@@ -372,6 +389,12 @@ def get_parser_validate() -> ArgumentParser:
372389
type=int,
373390
help="Raises an exception if a dynamic dimension becomes static.",
374391
)
392+
parser.add_argument(
393+
"--same-as-trained",
394+
default=False,
395+
action=BooleanOptionalAction,
396+
help="Validates a model identical to the trained model but not trained.",
397+
)
375398
parser.add_argument(
376399
"--trained",
377400
default=False,
@@ -487,7 +510,8 @@ def _cmd_validate(argv: List[Any]):
487510
do_run=args.run,
488511
verbose=args.verbose,
489512
quiet=args.quiet,
490-
trained=args.trained,
513+
same_as_pretrained=args.same_as_trained,
514+
use_pretrained=args.trained,
491515
dtype=args.dtype,
492516
device=args.device,
493517
patch=args.patch,
@@ -619,7 +643,13 @@ def get_parser_agg() -> ArgumentParser:
619643
and produces values. Every row has a date.
620644
"""
621645
),
622-
epilog="example\n python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1",
646+
epilog=textwrap.dedent(
647+
"""
648+
examples:\n
649+
650+
python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
651+
"""
652+
),
623653
formatter_class=RawTextHelpFormatter,
624654
)
625655
parser.add_argument("output", help="output excel file")

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
132132
If the configuration is None, the function selects typical dimensions.
133133
"""
134134
if config is not None:
135-
check_hasattr(
136-
config,
137-
"vocab_size",
138-
"hidden_size",
139-
"num_attention_heads",
140-
("num_key_value_heads", "num_attention_heads"),
141-
"intermediate_size",
142-
"hidden_size",
143-
"vision_config",
144-
)
135+
if hasattr(config, "text_config"):
136+
check_hasattr(
137+
config.text_config,
138+
"vocab_size",
139+
"hidden_size",
140+
"num_attention_heads",
141+
("num_key_value_heads", "num_attention_heads"),
142+
"intermediate_size",
143+
"hidden_size",
144+
)
145+
check_hasattr(config, "vision_config")
146+
text_config = True
147+
else:
148+
check_hasattr(
149+
config,
150+
"vocab_size",
151+
"hidden_size",
152+
"num_attention_heads",
153+
("num_key_value_heads", "num_attention_heads"),
154+
"intermediate_size",
155+
"hidden_size",
156+
"vision_config",
157+
)
158+
text_config = False
145159
check_hasattr(config.vision_config, "image_size", "num_channels")
146160
kwargs = dict(
147161
batch_size=2,
@@ -150,17 +164,55 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
150164
head_dim=(
151165
16
152166
if config is None
153-
else getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
167+
else getattr(
168+
config,
169+
"head_dim",
170+
(config.text_config.hidden_size if text_config else config.config.hidden_size)
171+
// (
172+
config.text_config.num_attention_heads
173+
if text_config
174+
else config.config.num_attention_heads
175+
),
176+
)
177+
),
178+
dummy_max_token_id=(
179+
31999
180+
if config is None
181+
else (config.text_config.vocab_size if text_config else config.config.vocab_size)
182+
- 1
183+
),
184+
num_hidden_layers=(
185+
4
186+
if config is None
187+
else (
188+
config.text_config.num_hidden_layers
189+
if text_config
190+
else config.config.num_hidden_layers
191+
)
154192
),
155-
dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
156-
num_hidden_layers=4 if config is None else config.num_hidden_layers,
157193
num_key_value_heads=(
158194
8
159195
if config is None
160-
else _pick(config, "num_key_value_heads", "num_attention_heads")
196+
else (
197+
_pick(config.text_config, "num_key_value_heads", "num_attention_heads")
198+
if text_config
199+
else _pick(config, "num_key_value_heads", "num_attention_heads")
200+
)
201+
),
202+
intermediate_size=(
203+
1024
204+
if config is None
205+
else (
206+
config.text_config.intermediate_size
207+
if text_config
208+
else config.config.intermediate_size
209+
)
210+
),
211+
hidden_size=(
212+
512
213+
if config is None
214+
else (config.text_config.hidden_size if text_config else config.hidden_size)
161215
),
162-
intermediate_size=1024 if config is None else config.intermediate_size,
163-
hidden_size=512 if config is None else config.hidden_size,
164216
width=224 if config is None else config.vision_config.image_size,
165217
height=224 if config is None else config.vision_config.image_size,
166218
num_channels=3 if config is None else config.vision_config.num_channels,

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,15 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
138138

139139

140140
@functools.cache
141-
def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
141+
def task_from_arch(
142+
arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
143+
) -> str:
142144
"""
143145
This function relies on stored information. That information needs to be refreshed.
144146
145147
:param arch: architecture name
146148
:param default_value: default value in case the task cannot be determined
149+
:param model_id: unused unless the architecture does not help.
147150
:return: task
148151
149152
.. runpython::
@@ -156,9 +159,16 @@ def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
156159
<onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
157160
"""
158161
data = load_architecture_task()
162+
if arch not in data and model_id:
163+
# Let's try with the model id.
164+
return task_from_id(model_id)
159165
if default_value is not None:
160166
return data.get(arch, default_value)
161-
assert arch in data, f"Architecture {arch!r} is unknown, last refresh in {__date__}"
167+
assert arch in data, (
168+
f"Architecture {arch!r} is unknown, last refresh in {__date__}. "
169+
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__`` "
170+
f"needs to be updated (model_id={(model_id or '?')!r})."
171+
)
162172
return data[arch]
163173

164174

@@ -176,6 +186,7 @@ def task_from_id(
176186
if the task cannot be determined
177187
:param pretrained: uses the config
178188
:param fall_back_to_pretrained: falls back to pretrained config
189+
:param exc: raises an exception if True
179190
:return: task
180191
"""
181192
if not pretrained:
@@ -191,11 +202,18 @@ def task_from_id(
191202
guess = _guess_task_from_config(config)
192203
if guess is not None:
193204
return guess
205+
data = load_architecture_task()
206+
if model_id in data:
207+
return data[model_id]
194208
assert config.architectures is not None and len(config.architectures) == 1, (
195209
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
196-
f"architectures={config.architectures} in config={config}"
210+
f"architectures={config.architectures} in config={config}. "
211+
f"The task can be added in "
212+
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
213+
)
214+
return task_from_arch(
215+
config.architectures[0], default_value=default_value, model_id=model_id
197216
)
198-
return task_from_arch(config.architectures[0], default_value=default_value)
199217

200218

201219
def task_from_tags(tags: Union[str, List[str]]) -> str:

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import textwrap
44
from typing import Dict, List
55

6-
__date__ = "2025-03-26"
6+
__date__ = "2025-06-21"
77

88
__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
99

@@ -52,6 +52,8 @@
5252
GPTNeoModel,feature-extraction
5353
GPTNeoXForCausalLM,text-generation
5454
GemmaForCausalLM,text-generation
55+
Gemma2ForCausalLM,text-generation
56+
Gemma3ForConditionalGeneration,image-text-to-text
5557
GraniteForCausalLM,text-generation
5658
GroupViTModel,feature-extraction
5759
HieraForImageClassification,image-classification
@@ -144,7 +146,8 @@
144146
XLMRobertaModel,sentence-similarity
145147
Wav2Vec2ForCTC,automatic-speech-recognition
146148
YolosForObjectDetection,object-detection
147-
YolosModel,image-feature-extraction"""
149+
YolosModel,image-feature-extraction
150+
emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
148151
)
149152

150153
__data_tasks__ = [

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import os
3+
import pprint
34
from typing import Any, Dict, Optional, Tuple
45
import torch
56
import transformers
@@ -22,6 +23,7 @@ def get_untrained_model_with_inputs(
2223
model_kwargs: Optional[Dict[str, Any]] = None,
2324
verbose: int = 0,
2425
dynamic_rope: Optional[bool] = None,
26+
use_pretrained: bool = False,
2527
same_as_pretrained: bool = False,
2628
use_preinstalled: bool = True,
2729
add_second_input: bool = False,
@@ -43,6 +45,7 @@ def get_untrained_model_with_inputs(
4345
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
4446
:param same_as_pretrained: if True, do not change the default values
4547
to get a smaller model
48+
:param use_pretrained: download the pretrained weights as well
4649
:param use_preinstalled: use preinstalled configurations
4750
:param add_second_input: provides a second set of inputs to check that a model
4851
supports different shapes
@@ -68,6 +71,10 @@ def get_untrained_model_with_inputs(
6871
print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
6972
print("-- configuration:", pprint.pformat(data['configuration']))
7073
"""
74+
assert not use_preinstalled or not use_only_preinstalled, (
75+
f"model_id={model_id!r}, pretinstalled model is only avaialble "
76+
f"if use_only_preinstalled is False."
77+
)
7178
if verbose:
7279
print(f"[get_untrained_model_with_inputs] model_id={model_id!r}")
7380
if use_preinstalled:
@@ -99,7 +106,7 @@ def get_untrained_model_with_inputs(
99106
print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
100107
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
101108
if task is None:
102-
task = task_from_arch(archs[0])
109+
task = task_from_arch(archs[0], model_id=model_id)
103110
if verbose:
104111
print(f"[get_untrained_model_with_inputs] task={task!r}")
105112

@@ -114,7 +121,6 @@ def get_untrained_model_with_inputs(
114121
)
115122

116123
# updating the configuration
117-
118124
mkwargs = reduce_model_config(config, task) if not same_as_pretrained else {}
119125
if model_kwargs:
120126
for k, v in model_kwargs.items():
@@ -139,27 +145,28 @@ def get_untrained_model_with_inputs(
139145
f"{config._attn_implementation!r}" # type: ignore[union-attr]
140146
)
141147

148+
if use_pretrained:
149+
model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
150+
else:
151+
if archs is not None:
152+
model = getattr(transformers, archs[0])(config)
153+
else:
154+
assert same_as_pretrained and use_pretrained, (
155+
f"Model {model_id!r} cannot be built, the model cannot be built. "
156+
f"It must be downloaded. Use same_as_pretrained=True "
157+
f"and use_pretrained=True."
158+
)
159+
142160
# input kwargs
143161
kwargs, fct = random_input_kwargs(config, task)
144162
if verbose:
145163
print(f"[get_untrained_model_with_inputs] use fct={fct}")
146164
if os.environ.get("PRINT_CONFIG") in (1, "1"):
147-
import pprint
148-
149165
print(f"-- input kwargs for task {task!r}")
150166
pprint.pprint(kwargs)
151167
if inputs_kwargs:
152168
kwargs.update(inputs_kwargs)
153169

154-
if archs is not None:
155-
model = getattr(transformers, archs[0])(config)
156-
else:
157-
assert same_as_pretrained, (
158-
f"Model {model_id!r} cannot be built, the model cannot be built. "
159-
f"It must be downloaded. Use same_as_pretrained=True."
160-
)
161-
model = None
162-
163170
# This line is important. Some models may produce different
164171
# outputs even with the same inputs in training mode.
165172
model.eval()

onnx_diagnostic/torch_models/validate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ def validate_model(
259259
verbose: int = 0,
260260
dtype: Optional[Union[str, torch.dtype]] = None,
261261
device: Optional[Union[str, torch.device]] = None,
262-
trained: bool = False,
262+
same_as_pretrained: bool = False,
263+
use_pretrained: bool = False,
263264
optimization: Optional[str] = None,
264265
quiet: bool = False,
265266
patch: bool = False,
@@ -294,7 +295,9 @@ def validate_model(
294295
:param verbose: verbosity level
295296
:param dtype: uses this dtype to check the model
296297
:param device: do the verification on this device
297-
:param trained: use the trained model, not the untrained one
298+
:param same_as_pretrained: use a model equivalent to the trained,
299+
this is not always possible
300+
:param use_pretrained: use the trained model, not the untrained one
298301
:param optimization: optimization to apply to the exported model,
299302
depends on the exporter
300303
:param quiet: if quiet, catches exception if any issue
@@ -353,7 +356,8 @@ def validate_model(
353356
version_do_run=str(do_run),
354357
version_dtype=str(dtype or ""),
355358
version_device=str(device or ""),
356-
version_trained=str(trained),
359+
version_same_as_pretrained=str(same_as_pretrained),
360+
version_use_pretrained=str(use_pretrained),
357361
version_optimization=optimization or "",
358362
version_quiet=str(quiet),
359363
version_patch=str(patch),
@@ -408,11 +412,12 @@ def validate_model(
408412
summary,
409413
None,
410414
(
411-
lambda mid=model_id, v=verbose, task=task, tr=trained, iop=iop, sub=subfolder, i2=inputs2: ( # noqa: E501
415+
lambda mid=model_id, v=verbose, task=task, uptr=use_pretrained, tr=same_as_pretrained, iop=iop, sub=subfolder, i2=inputs2: ( # noqa: E501
412416
get_untrained_model_with_inputs(
413417
mid,
414418
verbose=v,
415419
task=task,
420+
use_pretrained=uptr,
416421
same_as_pretrained=tr,
417422
inputs_kwargs=iop,
418423
model_kwargs=mop,

0 commit comments

Comments
 (0)