Skip to content

Commit cc0c5f2

Browse files
committed
test
1 parent 02a6e83 commit cc0c5f2

File tree

5 files changed

+427
-23
lines changed

5 files changed

+427
-23
lines changed

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@
33
import onnx
44
import torch
55
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
6+
from onnx_diagnostic.helpers import string_type
67
from onnx_diagnostic.helpers.torch_test_helper import (
78
dummy_llm,
89
to_numpy,
910
is_torchdynamo_exporting,
1011
steel_forward,
1112
replace_string_by_dynamic,
13+
to_any,
14+
torch_deepcopy,
1215
)
16+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
1317

1418
TFLOAT = onnx.TensorProto.FLOAT
1519

1620

17-
class TestOrtSession(ExtTestCase):
21+
class TestTorchTestHelper(ExtTestCase):
1822

1923
def test_is_torchdynamo_exporting(self):
2024
self.assertFalse(is_torchdynamo_exporting())
@@ -67,6 +71,29 @@ def test_replace_string_by_dynamic(self):
6771
sproc,
6872
)
6973

74+
def test_to_any(self):
75+
c1 = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
76+
c2 = make_encoder_decoder_cache(
77+
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
78+
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
79+
)
80+
a = {"t": [(torch.tensor([1, 2]), c1, c2), {4, 5}]}
81+
at = to_any(a, torch.float16)
82+
self.assertIn("T10r", string_type(at))
83+
84+
def test_torch_deepcopy(self):
85+
c1 = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
86+
c2 = make_encoder_decoder_cache(
87+
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
88+
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
89+
)
90+
a = {"t": [(torch.tensor([1, 2]), c1, c2), {4, 5}]}
91+
at = torch_deepcopy(a)
92+
hash1 = string_type(at, with_shape=True, with_min_max=True)
93+
c1.key_cache[0] += 1000
94+
hash2 = string_type(at, with_shape=True, with_min_max=True)
95+
self.assertEqual(hash1, hash2)
96+
7097

7198
if __name__ == "__main__":
7299
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import unittest
3-
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
44
from onnx_diagnostic.torch_models.test_helper import get_inputs_for_task, validate_model
55
from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
66

@@ -15,13 +15,39 @@ def test_get_inputs_for_task(self):
1515
self.assertIn("dynamic_shapes", data)
1616
copy.deepcopy(data["inputs"])
1717

18+
@hide_stdout()
1819
def test_validate_model(self):
1920
mid = "arnir0/Tiny-LLM"
2021
summary, data = validate_model(mid, do_run=True, verbose=2)
2122
self.assertIsInstance(summary, dict)
2223
self.assertIsInstance(data, dict)
2324
validate_model(mid, do_run=True, verbose=2, quiet=True)
2425

26+
@hide_stdout()
27+
def test_validate_model_dtype(self):
28+
mid = "arnir0/Tiny-LLM"
29+
summary, data = validate_model(
30+
mid, do_run=True, verbose=2, dtype="float32", device="cpu"
31+
)
32+
self.assertIsInstance(summary, dict)
33+
self.assertIsInstance(data, dict)
34+
validate_model(mid, do_run=True, verbose=2, quiet=True)
35+
36+
@hide_stdout()
37+
def test_validate_model_export(self):
38+
mid = "arnir0/Tiny-LLM"
39+
summary, data = validate_model(
40+
mid,
41+
do_run=True,
42+
verbose=2,
43+
dtype="float32",
44+
device="cpu",
45+
exporter="export-nostrict",
46+
)
47+
self.assertIsInstance(summary, dict)
48+
self.assertIsInstance(data, dict)
49+
validate_model(mid, do_run=True, verbose=2, quiet=False)
50+
2551

2652
if __name__ == "__main__":
2753
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,12 @@ def get_parser_validate() -> ArgumentParser:
254254
parser.add_argument(
255255
"-e",
256256
"--export",
257-
default=False,
258-
action=BooleanOptionalAction,
259-
help="runs the model to check it runs",
257+
help="export the model with this exporter",
258+
)
259+
parser.add_argument(
260+
"-o",
261+
"--opt",
262+
help="optimization to apply after the export",
260263
)
261264
parser.add_argument(
262265
"-r",
@@ -272,18 +275,22 @@ def get_parser_validate() -> ArgumentParser:
272275
action=BooleanOptionalAction,
273276
help="catches exception, report them in the summary",
274277
)
278+
parser.add_argument(
279+
"-p",
280+
"--patch",
281+
default=True,
282+
action=BooleanOptionalAction,
283+
help="applies patches before exporting",
284+
)
275285
parser.add_argument(
276286
"--trained",
277287
default=False,
278288
action=BooleanOptionalAction,
279289
help="validate the trained model (requires downloading)",
280290
)
281-
parser.add_argument(
282-
"-v",
283-
"--verbose",
284-
default=0,
285-
help="verbosity",
286-
)
291+
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
292+
parser.add_argument("--dtype", help="changes dtype if necessary")
293+
parser.add_argument("--device", help="changes the device if necessary")
287294
return parser
288295

289296

@@ -316,9 +323,14 @@ def _cmd_validate(argv: List[Any]):
316323
verbose=args.verbose,
317324
quiet=args.quiet,
318325
trained=args.trained,
326+
dtype=args.dtype,
327+
device=args.device,
328+
patch=args.patch,
329+
optimization=args.opt,
330+
exporter=args.export,
319331
)
320332
print("")
321-
print("-- summary")
333+
print("-- summary --")
322334
for k, v in sorted(summary.items()):
323335
print(f":{k},{v};")
324336

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import contextlib
2+
from collections.abc import Iterable
23
from typing import Any, Optional, Tuple, Union
4+
import numpy as np
35
import torch
46
from .helper import string_type
7+
from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache
58

69

710
def _forward_(*args, _f=None, _context=None, **kwargs):
@@ -298,3 +301,66 @@ def forward(self, input_ids):
298301
return dec, (x,)
299302

300303
raise NotImplementedError(f"cls_name={cls_name}")
304+
305+
306+
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
307+
"""
308+
Applies torch.to is applicables.
309+
Goes recursively.
310+
"""
311+
if isinstance(value, (torch.nn.Module, torch.Tensor)):
312+
return value.to(to_value)
313+
if isinstance(value, list):
314+
return [to_any(t, to_value) for t in value]
315+
if isinstance(value, tuple):
316+
return tuple(to_any(t, to_value) for t in value)
317+
if isinstance(value, set):
318+
return {to_any(t, to_value) for t in value}
319+
if isinstance(value, dict):
320+
return {k: to_any(t, to_value) for k, t in value.items()}
321+
if hasattr(value, "to"):
322+
return value.to(to_value)
323+
if value.__class__.__name__ == "DynamicCache":
324+
return make_dynamic_cache(
325+
list(
326+
zip(
327+
[t.to(to_value) for t in value.key_cache],
328+
[t.to(to_value) for t in value.value_cache],
329+
)
330+
)
331+
)
332+
333+
assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
334+
return value
335+
336+
337+
def torch_deepcopy(value: Any) -> Any:
338+
"""
339+
Makes a deepcopy.
340+
"""
341+
if isinstance(value, (int, float, str)):
342+
return value
343+
if isinstance(value, tuple):
344+
return tuple(torch_deepcopy(v) for v in value)
345+
if isinstance(value, list):
346+
return [torch_deepcopy(v) for v in value]
347+
if isinstance(value, set):
348+
return {torch_deepcopy(v) for v in value}
349+
if isinstance(value, dict):
350+
return {k: torch_deepcopy(v) for k, v in value.items()}
351+
if isinstance(value, np.ndarray):
352+
return value.copy()
353+
if hasattr(value, "clone"):
354+
return value.clone()
355+
if value.__class__.__name__ == "DynamicCache":
356+
return make_dynamic_cache(
357+
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
358+
)
359+
if value.__class__.__name__ == "EncoderDecoderCache":
360+
return make_encoder_decoder_cache(
361+
torch_deepcopy(value.self_attention_cache),
362+
torch_deepcopy(value.cross_attention_cache),
363+
)
364+
# We should have a code using serialization, deserialization assuming a model
365+
# cannot be exported without them.
366+
raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")

0 commit comments

Comments
 (0)