Skip to content

Commit fdced7a

Browse files
committed
val
1 parent 1512c22 commit fdced7a

File tree

6 files changed

+164
-29
lines changed

6 files changed

+164
-29
lines changed

_unittests/ut_helpers/test_memory_peak.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,21 @@ def test_spy_cpu(self):
3131
time.sleep(0.005)
3232
n_elements = max(value.shape[0], n_elements)
3333
time.sleep(0.02)
34-
pres = p.stop()
34+
measures = p.stop()
3535
self.assertGreater(n_elements, 0)
36-
self.assertIsInstance(pres, dict)
37-
self.assertLessEqual(pres["cpu"].end, pres["cpu"].max_peak)
38-
self.assertLessEqual(pres["cpu"].begin, pres["cpu"].max_peak)
39-
self.assertGreater(pres["cpu"].begin, 0)
36+
self.assertIsInstance(measures, dict)
37+
self.assertLessEqual(measures["cpu"].end, measures["cpu"].max_peak)
38+
self.assertLessEqual(measures["cpu"].begin, measures["cpu"].max_peak)
39+
self.assertGreater(measures["cpu"].begin, 0)
4040
# Zero should not happen...
41-
self.assertGreaterOrEqual(pres["cpu"].delta_peak, 0)
42-
self.assertGreaterOrEqual(pres["cpu"].delta_peak, pres["cpu"].delta_end)
43-
self.assertGreaterOrEqual(pres["cpu"].delta_peak, pres["cpu"].delta_avg)
44-
self.assertGreaterOrEqual(pres["cpu"].delta_end, 0)
45-
self.assertGreaterOrEqual(pres["cpu"].delta_avg, 0)
41+
self.assertGreaterOrEqual(measures["cpu"].delta_peak, 0)
42+
self.assertGreaterOrEqual(measures["cpu"].delta_peak, measures["cpu"].delta_end)
43+
self.assertGreaterOrEqual(measures["cpu"].delta_peak, measures["cpu"].delta_avg)
44+
self.assertGreaterOrEqual(measures["cpu"].delta_end, 0)
45+
self.assertGreaterOrEqual(measures["cpu"].delta_avg, 0)
4646
# Too unstable.
47-
# self.assertGreater(pres["cpu"].delta_peak, n_elements * 8 * 0.5)
48-
self.assertIsInstance(pres["cpu"].to_dict(), dict)
47+
# self.assertGreater(measures["cpu"].delta_peak, n_elements * 8 * 0.5)
48+
self.assertIsInstance(measures["cpu"].to_dict(), dict)
4949

5050
@skipif_ci_apple("stuck")
5151
@requires_cuda()
@@ -58,10 +58,10 @@ def test_spy_cuda(self):
5858
value += 1
5959
n_elements = max(value.shape[0], n_elements)
6060
time.sleep(0.02)
61-
pres = p.stop()
62-
self.assertIsInstance(pres, dict)
63-
self.assertIn("gpus", pres)
64-
gpu = pres["gpus"][0]
61+
measures = p.stop()
62+
self.assertIsInstance(measures, dict)
63+
self.assertIn("gpus", measures)
64+
gpu = measures["gpus"][0]
6565
self.assertLessEqual(gpu.end, gpu.max_peak)
6666
self.assertLessEqual(gpu.begin, gpu.max_peak)
6767
self.assertGreater(gpu.delta_peak, 0)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import unittest
33
from onnx_diagnostic.ext_test_case import ExtTestCase
4-
from onnx_diagnostic.torch_models.test_helper import get_inputs_for_task
4+
from onnx_diagnostic.torch_models.test_helper import get_inputs_for_task, validate_model
55
from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
66

77

@@ -15,6 +15,13 @@ def test_get_inputs_for_task(self):
1515
self.assertIn("dynamic_shapes", data)
1616
copy.deepcopy(data["inputs"])
1717

18+
def test_validate_model(self):
19+
mid = "arnir0/Tiny-LLM"
20+
summary, data = validate_model(mid, do_run=True, verbose=2)
21+
self.assertIsInstance(summary, dict)
22+
self.assertIsInstance(data, dict)
23+
validate_model(mid, do_run=True, verbose=2, quiet=True)
24+
1825

1926
if __name__ == "__main__":
2027
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,16 @@ def test_parser_config(self):
3737
def test_parser_validate(self):
3838
st = StringIO()
3939
with redirect_stdout(st):
40+
main(["validate"])
4041
main(["validate", "-t", "text-generation"])
4142
text = st.getvalue()
4243
self.assertIn("dynamic_shapes", text)
44+
st = StringIO()
45+
with redirect_stdout(st):
46+
main(["validate"])
47+
main(["validate", "-m", "arnir0/Tiny-LLM", "--run", "-v", "1"])
48+
text = st.getvalue()
49+
self.assertIn("model_clas", text)
4350

4451

4552
if __name__ == "__main__":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ def get_parser_validate() -> ArgumentParser:
233233
description=dedent(
234234
"""
235235
Prints out dummy inputs for a particular task or a model id.
236+
If both mid and task are empty, the command line displays the list
237+
of supported tasks.
236238
"""
237239
),
238240
epilog="If the model id is specified, one untrained version of it is instantiated.",
@@ -263,6 +265,19 @@ def get_parser_validate() -> ArgumentParser:
263265
action=BooleanOptionalAction,
264266
help="runs the model to check it runs",
265267
)
268+
parser.add_argument(
269+
"-q",
270+
"--quiet",
271+
default=False,
272+
action=BooleanOptionalAction,
273+
help="catches exceptions, reports them in the summary",
274+
)
275+
parser.add_argument(
276+
"--trained",
277+
default=False,
278+
action=BooleanOptionalAction,
279+
help="validate the trained model (requires downloading)",
280+
)
266281
parser.add_argument(
267282
"-v",
268283
"--verbose",
@@ -274,12 +289,15 @@ def get_parser_validate() -> ArgumentParser:
274289

275290
def _cmd_validate(argv: List[Any]):
276291
from .helpers import string_type
277-
from .torch_models.test_helper import get_inputs_for_task
292+
from .torch_models.test_helper import get_inputs_for_task, validate_model, _ds_clean
293+
from .torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
278294

279295
parser = get_parser_validate()
280296
args = parser.parse_args(argv[1:])
281-
assert args.task or args.mid, "A model id or a task needs to be specified."
282-
if not args.mid:
297+
if not args.task and not args.mid:
298+
print("-- list of supported tasks:")
299+
print("\n".join(sorted(get_get_inputs_function_for_tasks())))
300+
elif not args.mid:
283301
data = get_inputs_for_task(args.task)
284302
if args.verbose:
285303
print(f"task: {args.task}")
@@ -289,10 +307,20 @@ def _cmd_validate(argv: List[Any]):
289307
print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
290308
print("-- dynamic_shapes")
291309
for k, v in data["dynamic_shapes"].items():
292-
vs = str(v).replace("<class 'onnx_diagnostic.torch_models.hghub.model_inputs.", "").replace("'>", "").replace("_DimHint(type=<_DimHintType.DYNAMIC: 3>", "DYNAMIC").replace("_DimHint(type=<_DimHintType.AUTO: 3>", "AUTO")
293-
print(f" + {k.ljust(max_length)}: {vs}")
294-
295-
# validate_model(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
310+
print(f" + {k.ljust(max_length)}: {_ds_clean(v)}")
311+
else:
312+
summary, _data = validate_model(
313+
model_id=args.mid,
314+
task=args.task,
315+
do_run=args.run,
316+
verbose=args.verbose,
317+
quiet=args.quiet,
318+
trained=args.trained,
319+
)
320+
print("")
321+
print("-- summary")
322+
for k, v in sorted(summary.items()):
323+
print(f":{k},{v};")
296324

297325

298326
def get_main_parser() -> ArgumentParser:

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl
147147
If the configuration is None, the function selects typical dimensions.
148148
"""
149149
fcts = get_get_inputs_function_for_tasks()
150-
assert task in fcts, f"Unsupported task {task!r}, supprted are {sorted(fcts)}"
150+
assert task in fcts, f"Unsupported task {task!r}, supported are {sorted(fcts)}"
151151
if task == "text-generation":
152152
if config is not None:
153153
check_hasattr(
@@ -376,6 +376,7 @@ def get_untrained_model_with_inputs(
376376
res["configuration"] = config
377377
res["size"] = sizes[0]
378378
res["n_weights"] = sizes[1]
379+
res["task"] = task
379380

380381
update = {}
381382
for k, v in res.items():

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
1-
from typing import Any, Dict, Optional, Union
1+
import torch
2+
from typing import Any, Dict, Optional, Tuple, Union
3+
import time
4+
from ..helpers import string_type
5+
from .hghub import get_untrained_model_with_inputs
26
from .hghub.model_inputs import random_input_kwargs
37

48

9+
def _ds_clean(v):
10+
return (
11+
str(v)
12+
.replace("<class 'onnx_diagnostic.torch_models.hghub.model_inputs.", "")
13+
.replace("'>", "")
14+
.replace("_DimHint(type=<_DimHintType.DYNAMIC: 3>", "DYNAMIC")
15+
.replace("_DimHint(type=<_DimHintType.AUTO: 3>", "AUTO")
16+
)
17+
18+
519
def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, Any]:
620
"""
721
Returns dummy inputs for a specific task.
@@ -18,12 +32,90 @@ def validate_model(
1832
model_id: str,
1933
task: Optional[str] = None,
2034
do_run: bool = False,
21-
do_export: bool = False,
35+
exporter: Optional[str] = None,
2236
do_same: bool = False,
2337
verbose: int = 0,
24-
) -> Dict[str, Union[int, float, str]]:
38+
dtype: Optional[Union[str, torch.dtype]] = None,
39+
device: Optional[Union[str, torch.device]] = None,
40+
trained: bool = False,
41+
optimization: Optional[str] = None,
42+
quiet: bool = False,
43+
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
2544
"""
2645
Validates a model.
2746
28-
47+
:param model_id: model id to validate
48+
:param task: task used to generate the necessary inputs,
49+
can be left empty to use the default task for this model
50+
if it can be determined
51+
:param do_run: checks the model works with the defined inputs
52+
:param exporter: exports the model using this exporter,
53+
available list: ``export-strict``, ``export-nostrict``, ``onnx``
54+
:param do_same: checks the discrepancies of the exported model
55+
:param verbose: verbosity level
56+
:param dtype: uses this dtype to check the model
57+
:param device: do the verification on this device
58+
:param trained: use the trained model, not the untrained one
59+
:param optimization: optimization to apply to the exported model,
60+
depends on the exporter
61+
:param quiet: if quiet, catches exception if any issue
62+
:return: two dictionaries, one with some metrics,
63+
another one with whatever the function produces
2964
"""
65+
assert not trained, f"trained={trained} not supported yet"
66+
assert not dtype, f"dtype={dtype} not supported yet"
67+
assert not device, f"device={device} not supported yet"
68+
summary = {}
69+
if verbose:
70+
print(f"[validate_model] validate model id {model_id!r}")
71+
print("[validate_model] get dummy inputs...")
72+
summary["model_id"] = model_id
73+
begin = time.perf_counter()
74+
if quiet:
75+
try:
76+
data = get_untrained_model_with_inputs(model_id, verbose=verbose, task=task)
77+
except Exception as e:
78+
summary["ERR_create"] = e
79+
summary["time_create"] = time.perf_counter() - begin
80+
return summary, {}
81+
else:
82+
data = get_untrained_model_with_inputs(model_id, verbose=verbose, task=task)
83+
summary["time_create"] = time.perf_counter() - begin
84+
for k in ["task", "size", "n_weights"]:
85+
summary[f"model_{k.replace('_','')}"] = data[k]
86+
summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
87+
summary["model_shapes"] = _ds_clean(str(data["dynamic_shapes"]))
88+
summary["model_class"] = data["model"].__class__.__name__
89+
summary["model_config_class"] = data["configuration"].__class__.__name__
90+
summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "")
91+
summary["model_id"] = model_id
92+
if verbose:
93+
print(f"[validate_model] task={data['task']}")
94+
print(f"[validate_model] size={data['size']}")
95+
print(f"[validate_model] n_weights={data['n_weights']}")
96+
print(f"[validate_model] n_weights={data['n_weights']}")
97+
for k, v in data["inputs"].items():
98+
print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}")
99+
for k, v in data["dynamic_shapes"].items():
100+
print(f"[validate_model] +SHAPE {k}={_ds_clean(v)}")
101+
if do_run:
102+
if verbose:
103+
print("[validate_model] run the model...")
104+
begin = time.perf_counter()
105+
if quiet:
106+
try:
107+
expected = data["model"](**data["inputs"])
108+
except Exception as e:
109+
summary["ERR_run"] = e
110+
summary["time_run"] = time.perf_counter() - begin
111+
return summary, data
112+
else:
113+
expected = data["model"](**data["inputs"])
114+
summary["time_run"] = time.perf_counter() - begin
115+
summary["model_expected"] = string_type(expected, with_shape=True)
116+
if verbose:
117+
print("[validate_model] run the model")
118+
data["expected"] = expected
119+
if verbose:
120+
print("[validate_model] done.")
121+
return summary, data

0 commit comments

Comments
 (0)