Skip to content

Commit 94bae15

Browse files
authored
Supports model options in command lines (#70)
* Support model options in command lines * spell
1 parent 9d5f2a4 commit 94bae15

21 files changed

+173
-36
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.4.1
5+
+++++
6+
7+
* :pr:`70`: support model options in command lines
8+
49
0.4.0
510
+++++
611

_doc/api/tasks/index.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,33 @@
11
onnx_diagnostic.tasks
22
=====================
33

4+
All submodules contain the following three functions:
5+
6+
* ``reduce_model_config(config) -> kwargs``:
7+
updates the configuration to get a smaller model more suitable
8+
for unit tests
9+
* ``random_input_kwargs(config) -> kwargs, get_inputs``:
10+
produces values ``get_inputs`` can take to generate dummy inputs
11+
suitable for a model defined by its configuration
12+
* ``get_inputs(model, config, *args, **kwargs) -> dict(inputs=..., dynamic_shapes=...)``:
13+
generates the dummy inputs and dynamic shapes for a specific model and configuration.
14+
15+
For a specific task, you would write:
16+
17+
.. code-block:: python
18+
19+
kwargs, get_inputs = random_input_kwargs(config)
20+
dummies = get_inputs(model, config, **kwargs)
21+
22+
Or:
23+
24+
.. code-block:: python
25+
26+
from onnx_diagnostic.tasks import random_input_kwargs
27+
28+
kwargs, get_inputs = random_input_kwargs(config, task) # "text-generation" for example
29+
dummies = get_inputs(model, config, **kwargs)
30+
431
.. toctree::
532
:maxdepth: 1
633
:caption: modules

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ Size of the package:
173173
Older versions
174174
++++++++++++++
175175

176+
* `0.4.1 <../v0.4.1/index.html>`_
176177
* `0.4.0 <../v0.4.0/index.html>`_
177178
* `0.3.0 <../v0.3.0/index.html>`_
178179
* `0.2.2 <../v0.2.2/index.html>`_

_unittests/ut_tasks/try_tasks.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,32 @@ def test_text2text_generation(self):
9898
)
9999
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
100100

101+
@never_test()
def test_text_generation_phi4(self):
    # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4
    # Manual exploration test (only runs with NEVERTEST=1): generates text with
    # microsoft/Phi-4-mini-instruct while steal_forward prints the forward calls.

    import torch

    # Fix: Phi-4 is a decoder-only causal LM. The original body was copy-pasted
    # from the T5 test above and used RobertaTokenizer/T5ForConditionalGeneration,
    # which cannot load this checkpoint; it also passed decoder_input_ids, which
    # only applies to encoder-decoder models.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
    model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-mini-instruct")

    text = "def greet(user): print(f'hello <extra_id_0>!')"
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    # Attend to every prompt token: same values/shape/dtype as the original
    # hand-built mask, expressed with the idiomatic helper.
    mask = torch.ones_like(input_ids, dtype=torch.int64)

    # simply generate a single sequence
    print()
    with steal_forward(model):
        generated_ids = model.generate(
            input_ids, attention_mask=mask, max_length=100
        )
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
126+
101127
@never_test()
102128
def test_imagetext2text_generation(self):
103129
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k etext2t

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,14 @@ def test_task_from_id_long(self):
7474
def test_get_pretrained_config(self):
7575
conf = get_pretrained_config("microsoft/phi-2")
7676
self.assertNotEmpty(conf)
77-
print(conf)
77+
78+
@requires_transformers("4.50")
79+
@requires_torch("2.7")
80+
@hide_stdout()
81+
def test_get_pretrained_config_options(self):
82+
conf = get_pretrained_config("microsoft/phi-2", num_key_value_heads=16)
83+
self.assertNotEmpty(conf)
84+
self.assertEqual(conf.num_key_value_heads, 16)
7885

7986
@requires_transformers("4.50")
8087
@requires_torch("2.7")

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.4.0"
6+
__version__ = "0.4.1"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/_command_lines_parser.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,22 @@ def get_parser_config() -> ArgumentParser:
214214
action=BooleanOptionalAction,
215215
help="displays the task as well",
216216
)
217+
parser.add_argument(
218+
"-c",
219+
"--cached",
220+
default=True,
221+
action=BooleanOptionalAction,
222+
help="uses cached configuration, only available for some of them, "
223+
"mostly for unit test purposes",
224+
)
225+
parser.add_argument(
226+
"--mop",
227+
metavar="KEY=VALUE",
228+
nargs="*",
229+
help="Additional model options, use to change some parameters of the model, "
230+
"example: --mop attn_implementation=eager",
231+
action=_ParseDict,
232+
)
217233
return parser
218234

219235

@@ -222,7 +238,11 @@ def _cmd_config(argv: List[Any]):
222238

223239
parser = get_parser_config()
224240
args = parser.parse_args(argv[1:])
225-
print(get_pretrained_config(args.mid))
241+
conf = get_pretrained_config(args.mid, **(args.mop or {}))
242+
print(conf)
243+
for k, v in sorted(conf.__dict__.items()):
244+
if "_implementation" in k:
245+
print(f"config.{k}={v!r}")
226246
if args.task:
227247
print("------")
228248
print(f"task: {task_from_id(args.mid)}")
@@ -238,6 +258,19 @@ def __call__(self, parser, namespace, values, option_string=None):
238258
key = split_items[0].strip() # we remove blanks around keys, as is logical
239259
value = split_items[1]
240260

261+
if value in ("True", "true", "False", "false"):
262+
d[key] = bool(value)
263+
continue
264+
try:
265+
d[key] = int(value)
266+
continue
267+
except (TypeError, ValueError):
268+
pass
269+
try:
270+
d[key] = float(value)
271+
continue
272+
except (TypeError, ValueError):
273+
pass
241274
d[key] = value
242275

243276
setattr(namespace, self.dest, d)
@@ -321,6 +354,14 @@ def get_parser_validate() -> ArgumentParser:
321354
"inputs use to export, example: --iop cls_cache=SlidingWindowCache",
322355
action=_ParseDict,
323356
)
357+
parser.add_argument(
358+
"--mop",
359+
metavar="KEY=VALUE",
360+
nargs="*",
361+
help="Additional model options, use to change some parameters of the model, "
362+
"example: --mop attn_implementation=eager",
363+
action=_ParseDict,
364+
)
324365
return parser
325366

326367

@@ -371,6 +412,7 @@ def _cmd_validate(argv: List[Any]):
371412
drop_inputs=None if not args.drop else args.drop.split(","),
372413
ortfusiontype=args.ortfusiontype,
373414
input_options=args.iop,
415+
model_options=args.mop,
374416
)
375417
print("")
376418
print("-- summary --")

onnx_diagnostic/helpers/config_helper.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@ def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]):
2828
def update_config(config: Any, mkwargs: Dict[str, Any]):
    """Updates a configuration with different values.

    :param config: configuration object, modified in place
    :param mkwargs: maps attribute names to new values; a dict value is
        treated as a nested configuration and applied recursively
    """
    for k, v in mkwargs.items():
        if k == "attn_implementation":
            # NOTE(review): writes the private attribute directly — presumably
            # to bypass transformers' setter/validation logic; confirm against
            # the installed transformers version.
            config._attn_implementation = v
            if getattr(config, "_attn_implementation_autoset", False):
                # clear the autoset flag so the explicit choice above is
                # presumably not overridden later by transformers — TODO confirm
                config._attn_implementation_autoset = False
            continue
        if isinstance(v, dict):
            assert hasattr(
                config, k
            ), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
            # nested configuration: recurse into the existing sub-config
            update_config(getattr(config, k), v)
            continue
        setattr(config, k, v)
3843

3944

4045
def _pick(config, *atts):

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
3333
"""Reduces a model size."""
3434
tasks = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
3535
assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
36-
return tasks[task](config, task)
36+
return tasks[task](config)
3737

3838

3939
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
@@ -45,4 +45,4 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl
4545
"""
4646
tasks = {mod.__TASK__: mod.random_input_kwargs for mod in __TASKS__}
4747
assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
48-
return tasks[task](config, task)
48+
return tasks[task](config)

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
__TASK__ = "automatic-speech-recognition"
88

99

10-
def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
10+
def reduce_model_config(config: Any) -> Dict[str, Any]:
1111
"""Reduces a model size."""
1212
kwargs: Dict[str, Any] = {}
1313
if hasattr(config, "num_decoder_layers"):
@@ -129,7 +129,7 @@ def get_inputs(
129129
return dict(inputs=inputs, dynamic_shapes=shapes)
130130

131131

132-
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
132+
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
133133
"""
134134
Inputs kwargs.
135135

0 commit comments

Comments
 (0)