Commit 1512c22

add cmd
1 parent 0ce5a6a commit 1512c22

9 files changed, +210 -73 lines changed


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.3.0
 +++++
 
+* :pr:`30`: adds command to test a model id
 * :pr:`29`: adds helpers to measure the memory peak and run benchmark
   on different processes
 * :pr:`28`: adds command line to print out the configuration for a model id,

_doc/api/torch_models/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ onnx_diagnostic.torch_models
 
    hghub/index
    llms
+   test_helper
 
 .. automodule:: onnx_diagnostic.torch_models
     :members:
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_models.test_helper
+========================================
+
+.. automodule:: onnx_diagnostic.torch_models.test_helper
+    :members:
+    :no-undoc-members:
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import copy
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.torch_models.test_helper import get_inputs_for_task
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
+
+
+class TestTestHelper(ExtTestCase):
+    def test_get_inputs_for_task(self):
+        fcts = get_get_inputs_function_for_tasks()
+        for task in self.subloop(sorted(fcts)):
+            data = get_inputs_for_task(task)
+            self.assertIsInstance(data, dict)
+            self.assertIn("inputs", data)
+            self.assertIn("dynamic_shapes", data)
+            copy.deepcopy(data["inputs"])
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
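
The module under test is one of the files in this commit that is not expanded above, but the assertions pin down its contract: get_inputs_for_task returns a dictionary carrying at least "inputs" and "dynamic_shapes", and the inputs survive copy.deepcopy. A minimal usage sketch built only on that contract (nothing here is captured from an actual run):

    # Usage sketch based on the contract exercised by the test above.
    from onnx_diagnostic.torch_models.test_helper import get_inputs_for_task

    data = get_inputs_for_task("text-generation")
    assert "inputs" in data and "dynamic_shapes" in data
    for name, value in data["inputs"].items():
        # tensors are expected here; fall back to the raw value otherwise
        print(name, getattr(value, "shape", value))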

_unittests/ut_xrun_doc/test_command_lines.py

Lines changed: 9 additions & 1 deletion
@@ -4,11 +4,12 @@
 from onnx_diagnostic.ext_test_case import ExtTestCase
 from onnx_diagnostic._command_lines_parser import (
     get_main_parser,
+    get_parser_config,
     get_parser_find,
     get_parser_lighten,
     get_parser_print,
     get_parser_unlighten,
-    get_parser_config,
+    get_parser_validate,
 )
 
 
@@ -55,6 +56,13 @@ def test_parser_config(self):
         text = st.getvalue()
         self.assertIn("mid", text)
 
+    def test_parser_validate(self):
+        st = StringIO()
+        with redirect_stdout(st):
+            get_parser_validate().print_help()
+        text = st.getvalue()
+        self.assertIn("mid", text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,13 @@ def test_parser_config(self):
         text = st.getvalue()
         self.assertIn("LlamaForCausalLM", text)
 
+    def test_parser_validate(self):
+        st = StringIO()
+        with redirect_stdout(st):
+            main(["validate", "-t", "text-generation"])
+        text = st.getvalue()
+        self.assertIn("dynamic_shapes", text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 29 additions & 7 deletions
@@ -227,22 +227,20 @@ def _cmd_config(argv: List[Any]):
     print(f"task: {task_from_id(args.mid)}")
 
 
-def get_parser_inputs() -> ArgumentParser:
+def get_parser_validate() -> ArgumentParser:
     parser = ArgumentParser(
         prog="test",
         description=dedent(
             """
             Prints out dummy inputs for a particular task or a model id.
             """
         ),
-        epilog="If the model id is specified, one untrained "
-        "version of it is instantiated.",
+        epilog="If the model id is specified, one untrained version of it is instantiated.",
     )
     parser.add_argument(
         "-m",
         "--mid",
         type=str,
-        required=True,
         help="model id, usually <author>/<name>",
     )
     parser.add_argument(
@@ -274,6 +272,29 @@ def get_parser_inputs() -> ArgumentParser:
     return parser
 
 
+def _cmd_validate(argv: List[Any]):
+    from .helpers import string_type
+    from .torch_models.test_helper import get_inputs_for_task
+
+    parser = get_parser_validate()
+    args = parser.parse_args(argv[1:])
+    assert args.task or args.mid, "A model id or a task needs to be specified."
+    if not args.mid:
+        data = get_inputs_for_task(args.task)
+        if args.verbose:
+            print(f"task: {args.task}")
+        max_length = max(len(k) for k in data["inputs"]) + 1
+        print("-- inputs")
+        for k, v in data["inputs"].items():
+            print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
+        print("-- dynamic_shapes")
+        for k, v in data["dynamic_shapes"].items():
+            vs = str(v).replace("<class 'onnx_diagnostic.torch_models.hghub.model_inputs.", "").replace("'>", "").replace("_DimHint(type=<_DimHintType.DYNAMIC: 3>", "DYNAMIC").replace("_DimHint(type=<_DimHintType.AUTO: 3>", "AUTO")
+            print(f" + {k.ljust(max_length)}: {vs}")
+
+    # validate_model(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
+
+
 def get_main_parser() -> ArgumentParser:
     parser = ArgumentParser(
         prog="onnx_diagnostic",
@@ -289,13 +310,13 @@ def get_main_parser() -> ArgumentParser:
         lighten - makes an onnx model lighter by removing the weights,
         unlighten - restores an onnx model produces by the previous experiment
         print - prints the model on standard output
-        test - tests a model
+        validate - validate a model
         """
         ),
     )
     parser.add_argument(
         "cmd",
-        choices=["config", "find", "lighten", "print", "unlighten", "test"],
+        choices=["config", "find", "lighten", "print", "unlighten", "validate"],
         help="Selects a command.",
     )
     return parser
@@ -308,7 +329,7 @@ def main(argv: Optional[List[Any]] = None):
         print=_cmd_print,
         find=_cmd_find,
         config=_cmd_config,
-        text=_cmd_test,
+        validate=_cmd_validate,
     )
 
     if argv is None:
@@ -328,6 +349,7 @@ def main(argv: Optional[List[Any]] = None):
         print=get_parser_print,
         find=get_parser_find,
         config=get_parser_config,
+        validate=get_parser_validate,
     )
     cmd = argv[0]
     if cmd not in parsers:
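
With the wiring above, the new subcommand can be driven programmatically through main, exactly as test_command_lines_exe.py above does; from a shell it presumably maps to an onnx_diagnostic validate ... invocation, depending on how the entry point is installed. A short sketch:

    # Programmatic equivalent of the CLI call exercised in
    # _unittests/ut_xrun_doc/test_command_lines_exe.py above.
    from onnx_diagnostic._command_lines_parser import main

    # prints the dummy inputs and their dynamic shapes for the task
    main(["validate", "-t", "text-generation"])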

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 107 additions & 65 deletions
@@ -141,69 +141,92 @@ def _pick(config, *atts):
 
 
 def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
-    """Inputs kwargs"""
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    fcts = get_get_inputs_function_for_tasks()
+    assert task in fcts, f"Unsupported task {task!r}, supported are {sorted(fcts)}"
     if task == "text-generation":
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_key_value_heads", "num_attention_heads"),
-            "intermediate_size",
-            "hidden_size",
-        )
+        if config is not None:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+            )
         kwargs = dict(
             batch_size=2,
             sequence_length=30,
             sequence_length2=3,
-            head_dim=getattr(
-                config, "head_dim", config.hidden_size // config.num_attention_heads
+            head_dim=(
+                16
+                if config is None
+                else getattr(
+                    config, "head_dim", config.hidden_size // config.num_attention_heads
+                )
             ),
-            dummy_max_token_id=config.vocab_size - 1,
-            num_hidden_layers=config.num_hidden_layers,
-            num_key_value_heads=_pick(config, "num_key_value_heads", "num_attention_heads"),
-            intermediate_size=config.intermediate_size,
-            hidden_size=config.hidden_size,
+            dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+            num_hidden_layers=4 if config is None else config.num_hidden_layers,
+            num_key_value_heads=(
+                24
+                if config is None
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
+            ),
+            intermediate_size=1024 if config is None else config.intermediate_size,
+            hidden_size=512 if config is None else config.hidden_size,
         )
         fct = get_inputs_for_text_generation
     elif task == "text2text-generation":
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
+        if config is not None:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_hidden_layers", "num_layers"),
+                ("n_positions", "d_model"),
+                (
+                    "num_key_value_heads",
+                    "num_heads",
+                    ("decoder_attention_heads", "encoder_attention_heads"),
+                ),
+            )
         kwargs = dict(
             batch_size=2,
             sequence_length=30,
             sequence_length2=3,
-            head_dim=config.d_kv if hasattr(config, "d_kv") else 1,
-            dummy_max_token_id=config.vocab_size - 1,
-            num_hidden_layers=_pick(config, "num_hidden_layers", "num_layers"),
-            num_key_value_heads=_pick(
-                config,
-                "num_key_value_heads",
-                "num_heads",
-                (sum, "encoder_attention_heads", "decoder_attention_heads"),
+            head_dim=16 if config is None else (config.d_kv if hasattr(config, "d_kv") else 1),
+            dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+            num_hidden_layers=(
+                8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+            ),
+            num_key_value_heads=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "num_key_value_heads",
+                    "num_heads",
+                    (sum, "encoder_attention_heads", "decoder_attention_heads"),
+                )
             ),
-            encoder_dim=_pick(config, "n_positions", "d_model"),
+            encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
         )
         fct = get_inputs_for_text2text_generation  # type: ignore
     elif task == "image-classification":
-        check_hasattr(config, "image_size", "num_channels")
-        if isinstance(config.image_size, int):
+        if config is not None:
+            check_hasattr(config, "image_size", "num_channels")
+        if config is None or isinstance(config.image_size, int):
             kwargs = dict(
                 batch_size=2,
-                input_width=config.image_size,
-                input_height=config.image_size,
-                input_channels=config.num_channels,
+                input_width=224 if config is None else config.image_size,
+                input_height=224 if config is None else config.image_size,
+                input_channels=3 if config is None else config.num_channels,
             )
         else:
             kwargs = dict(
@@ -214,32 +237,41 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
             )
         fct = get_inputs_for_image_classification  # type: ignore
     elif task == "image-text-to-text":
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_key_value_heads", "num_attention_heads"),
-            "intermediate_size",
-            "hidden_size",
-            "vision_config",
-        )
-        check_hasattr(config.vision_config, "image_size", "num_channels")
+        if config is not None:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+                "vision_config",
+            )
+            check_hasattr(config.vision_config, "image_size", "num_channels")
         kwargs = dict(
             batch_size=2,
             sequence_length=30,
             sequence_length2=3,
-            head_dim=getattr(
-                config, "head_dim", config.hidden_size // config.num_attention_heads
+            head_dim=(
+                16
+                if config is None
+                else getattr(
+                    config, "head_dim", config.hidden_size // config.num_attention_heads
+                )
+            ),
+            dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+            num_hidden_layers=4 if config is None else config.num_hidden_layers,
+            num_key_value_heads=(
+                8
+                if config is None
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
             ),
-            dummy_max_token_id=config.vocab_size - 1,
-            num_hidden_layers=config.num_hidden_layers,
-            num_key_value_heads=_pick(config, "num_key_value_heads", "num_attention_heads"),
-            intermediate_size=config.intermediate_size,
-            hidden_size=config.hidden_size,
-            width=config.vision_config.image_size,
-            height=config.vision_config.image_size,
-            num_channels=config.vision_config.num_channels,
+            intermediate_size=1024 if config is None else config.intermediate_size,
+            hidden_size=512 if config is None else config.hidden_size,
+            width=224 if config is None else config.vision_config.image_size,
+            height=224 if config is None else config.vision_config.image_size,
+            num_channels=3 if config is None else config.vision_config.num_channels,
         )
         fct = get_inputs_for_image_text_to_text  # type: ignore
     else:
@@ -682,3 +714,13 @@ def get_inputs_for_text2text_generation(
         # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
     )
     return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
+def get_get_inputs_function_for_tasks() -> Dict[str, Callable]:
+    """Returns all the functions producing dummy inputs for every task."""
+    return {
+        "image-classification": get_inputs_for_image_classification,
+        "text-generation": get_inputs_for_text_generation,
+        "text2text-generation": get_inputs_for_text2text_generation,
+        "image-text-to-text": get_inputs_for_image_text_to_text,
+    }
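
The net effect of these changes is that random_input_kwargs no longer requires a configuration object: with config=None it falls back to the typical dimensions hard-coded above. A short sketch of the fallback (expected values copied from the defaults in the diff):

    # Sketch of the new config=None fallback introduced above.
    from onnx_diagnostic.torch_models.hghub.model_inputs import (
        get_get_inputs_function_for_tasks,
        random_input_kwargs,
    )

    print(sorted(get_get_inputs_function_for_tasks()))
    # ['image-classification', 'image-text-to-text', 'text-generation', 'text2text-generation']

    kwargs, fct = random_input_kwargs(None, "text-generation")
    print(fct.__name__)                 # get_inputs_for_text_generation
    print(kwargs["hidden_size"])        # 512, default picked when config is None
    print(kwargs["num_hidden_layers"])  # 4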

0 commit comments
