7 changes: 5 additions & 2 deletions _unittests/ut_torch_export_patches/test_onnx_export_errors.py
@@ -5,6 +5,7 @@
requires_transformers,
skipif_ci_windows,
ignore_warnings,
hide_stdout,
)
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -16,6 +17,7 @@ class TestOnnxExportErrors(ExtTestCase):
@requires_transformers("4.49.999")
@skipif_ci_windows("not working on Windows")
@ignore_warnings(UserWarning)
@hide_stdout()
def test_pytree_flatten_mamba_cache(self):
import torch
import torch.utils._pytree as py_pytree
@@ -31,7 +33,7 @@ def __init__(self):

cache = MambaCache(_config(), max_batch_size=1, device="cpu")

with bypass_export_some_errors():
with bypass_export_some_errors(verbose=1):
values, spec = py_pytree.tree_flatten(cache)
cache2 = py_pytree.tree_unflatten(values, spec)
self.assertEqual(cache.dtype, cache2.dtype)
@@ -46,6 +48,7 @@ def __init__(self):
@requires_torch("2.7")
@skipif_ci_windows("not working on Windows")
@ignore_warnings(UserWarning)
@hide_stdout()
def test_exportable_mamba_cache(self):
import torch
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -73,7 +76,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
model = Model()
model(x, cache)

with bypass_export_some_errors():
with bypass_export_some_errors(replace_dynamic_cache=True, verbose=1):
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
torch.export.export(Model(), (x, cache))

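Note on the change above: the tests now run bypass_export_some_errors with verbose=1, so @hide_stdout() is added to keep the extra logging out of the test report. As a rough sketch of what such a decorator amounts to (the real one lives in onnx_diagnostic.ext_test_case and may do more, e.g. re-emit the capture when the test fails), using only the standard library:

import contextlib
import functools
import io


def hide_stdout_sketch():
    """Swallow anything the wrapped test prints to stdout (illustrative only)."""

    def wrapper(fct):
        @functools.wraps(fct)
        def call(*args, **kwargs):
            # Redirect stdout into a throwaway buffer for the duration of the test.
            with contextlib.redirect_stdout(io.StringIO()):
                return fct(*args, **kwargs)

        return call

    return wrapper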
24 changes: 24 additions & 0 deletions _unittests/ut_xrun_doc/test_args.py
@@ -25,6 +25,30 @@ def test_args(self):
self.assertEqual(args.with_mask, 0)
self.assertEqual(args.optim, "")

def test_args_expose(self):
try:
args = get_parsed_args(
"plot_custom_backend_llama",
config=(
"medium",
"large or medium depending, large means closer to the real model",
),
num_hidden_layers=(1, "number of hidden layers"),
with_mask=(0, "tries with a mask as a secondary input"),
optim=("", "Optimization to apply, empty string for all"),
description="doc",
new_args=["--config", "m"],
expose="repeat,warmup",
)
except SystemExit as e:
raise AssertionError(f"SystemExist caught: {e}")
self.assertEqual(args.config, "m")
self.assertEqual(args.num_hidden_layers, 1)
self.assertEqual(args.with_mask, 0)
self.assertEqual(args.optim, "")
self.assertEqual(args.repeat, 10)
self.assertEqual(args.warmup, 5)


if __name__ == "__main__":
unittest.main(verbosity=2)
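Editor's note on the new test: expose="repeat,warmup" asks get_parsed_args to register the extra --repeat and --warmup options, and the assertions pin their defaults at 10 and 5. A hypothetical sketch of what that exposure boils down to with plain argparse (the real helper is part of onnx_diagnostic and handles more options than shown here):

import argparse

EXPOSED_DEFAULTS = {"repeat": 10, "warmup": 5}  # defaults the test expects


def parse(expose, new_args=None):
    parser = argparse.ArgumentParser("plot_custom_backend_llama", description="doc")
    parser.add_argument("--config", default="medium")
    parser.add_argument("--num_hidden_layers", type=int, default=1)
    # Exposed benchmark options are added on demand, comma-separated.
    for name in expose.split(","):
        parser.add_argument(f"--{name}", type=int, default=EXPOSED_DEFAULTS[name])
    return parser.parse_args(new_args)


args = parse("repeat,warmup", new_args=["--config", "m"])
assert (args.config, args.repeat, args.warmup) == ("m", 10, 5)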
51 changes: 49 additions & 2 deletions _unittests/ut_xrun_doc/test_helpers.py
@@ -5,7 +5,7 @@
import onnx
import onnx.helper as oh
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows
from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows, hide_stdout
from onnx_diagnostic.helpers import (
string_type,
string_sig,
@@ -127,7 +127,22 @@ def test_flatten(self):
diff = max_diff(inputs, flat, flatten=True)
self.assertEqual(diff["abs"], 0)
d = string_diff(diff)
print(d)
self.assertIsInstance(d, str)

@hide_stdout()
def test_max_diff_verbose(self):
inputs = (
torch.rand((3, 4), dtype=torch.float16),
[
torch.rand((5, 6), dtype=torch.float16),
torch.rand((5, 6, 7), dtype=torch.float16),
],
)
flat = flatten_object(inputs)
diff = max_diff(inputs, flat, flatten=True, verbose=10)
self.assertEqual(diff["abs"], 0)
d = string_diff(diff)
self.assertIsInstance(d, str)

def test_type_info(self):
for tt in [
@@ -250,6 +265,38 @@ def test_string_type_one(self):
self.assertEqual(string_type([4] * 100), "#100[int,...]")
self.assertEqual(string_type((4,) * 100), "#100(int,...)")

def test_string_type_one_with_min_max_int(self):
self.assertEqual(string_type(None, with_min_max=True), "None")
self.assertEqual(string_type([4], with_min_max=True), "#1[int=4]")
self.assertEqual(string_type((4, 5), with_min_max=True), "(int=4,int=5)")
self.assertEqual(string_type([4] * 100, with_min_max=True), "#100[int=4,...][4,4:4.0]")
self.assertEqual(
string_type((4,) * 100, with_min_max=True), "#100(int=4,...)[4,4:A[4.0]]"
)

def test_string_type_one_with_min_max_bool(self):
self.assertEqual(string_type(None, with_min_max=True), "None")
self.assertEqual(string_type([True], with_min_max=True), "#1[bool=True]")
self.assertEqual(string_type((True, True), with_min_max=True), "(bool=True,bool=True)")
self.assertEqual(
string_type([True] * 100, with_min_max=True), "#100[bool=True,...][True,True:1.0]"
)
self.assertEqual(
string_type((True,) * 100, with_min_max=True),
"#100(bool=True,...)[True,True:A[1.0]]",
)

def test_string_type_one_with_min_max_float(self):
self.assertEqual(string_type(None, with_min_max=True), "None")
self.assertEqual(string_type([4.5], with_min_max=True), "#1[float=4.5]")
self.assertEqual(string_type((4.5, 5.5), with_min_max=True), "(float=4.5,float=5.5)")
self.assertEqual(
string_type([4.5] * 100, with_min_max=True), "#100[float=4.5,...][4.5,4.5:4.5]"
)
self.assertEqual(
string_type((4.5,) * 100, with_min_max=True), "#100(float=4.5,...)[4.5,4.5:A[4.5]]"
)

def test_string_type_at(self):
self.assertEqual(string_type(None), "None")
a = np.array([4, 5], dtype=np.float32)
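The new with_min_max tests above encode the summary convention also touched in onnx_diagnostic/helpers.py further down: a #<count> prefix, the element type with its first value, then [min,max:avg] for lists and [min,max:A[avg]] for tuples. A minimal reproduction of the tuple case, built only from the expected strings in the tests (variable names are assumptions):

# Mirrors the fixed f-string in helpers.py below.
obj = (4,) * 100
tt = f"{type(obj[0]).__name__}={obj[0]}"
mini, maxi, avg = min(obj), max(obj), sum(float(x) for x in obj) / len(obj)
assert f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]" == "#100(int=4,...)[4,4:A[4.0]]"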
5 changes: 3 additions & 2 deletions _unittests/ut_xrun_doc/test_onnx_tools.py
@@ -4,7 +4,7 @@
import onnx.numpy_helper as onh
from onnx import TensorProto
from onnx.checker import check_model
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.onnx_tools import (
onnx_lighten,
onnx_unlighten,
@@ -72,6 +72,7 @@ def test_onnx_find(self):
self.assertIn("xm2", res[0].output)
self.assertIn("xm2", res[1].input)

@hide_stdout()
def test__validate_function(self):
new_domain = "custom"

@@ -87,7 +88,7 @@ def test__validate_function(self):
[oh.make_opsetid("", 14)],
[],
)
_validate_function(linear_regression)
_validate_function(linear_regression, verbose=1)


if __name__ == "__main__":
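For context, the linear_regression object passed to _validate_function above is a FunctionProto; the diff only shows its trailing arguments ([oh.make_opsetid("", 14)], []). A hedged reconstruction of such a function with onnx.helper — the name, inputs and node layout here are assumptions, not necessarily what the test builds:

import onnx.helper as oh

new_domain = "custom"
linear_regression = oh.make_function(
    new_domain,                  # function domain
    "LinearRegression",          # function name (assumed)
    ["x", "a", "b"],             # inputs (assumed)
    ["y"],                       # outputs (assumed)
    [                            # body: y = x @ a + b (assumed)
        oh.make_node("MatMul", ["x", "a"], ["xa"]),
        oh.make_node("Add", ["xa", "b"], ["y"]),
    ],
    [oh.make_opsetid("", 14)],   # opset imports, as in the diff
    [],                          # attribute names, as in the diff
)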
18 changes: 18 additions & 0 deletions _unittests/ut_xrun_doc/test_unit_test.py
@@ -84,6 +84,24 @@ def test_measure_time(self):
},
)

def test_measure_time_max(self):
res = measure_time(lambda: math.cos(0.5), max_time=0.1)
self.assertIsInstance(res, dict)
self.assertEqual(
set(res),
{
"min_exec",
"max_exec",
"average",
"warmup_time",
"context_size",
"deviation",
"repeat",
"ttime",
"number",
},
)


if __name__ == "__main__":
unittest.main(verbosity=2)
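The new test exercises the max_time=0.1 path of measure_time and only checks the returned keys. A rough sketch of a max_time-driven timing loop that produces the same keys (the real helper in onnx_diagnostic.ext_test_case differs in details such as how warmup_time and context_size are computed):

import math
import statistics
import time


def measure_time_sketch(fct, max_time=0.1, number=10, warmup=5):
    # Warm the function up and time that phase separately.
    begin = time.perf_counter()
    for _ in range(warmup):
        fct()
    warmup_time = time.perf_counter() - begin

    # Keep repeating batches of `number` calls until max_time is spent.
    runs, total = [], 0.0
    while total < max_time:
        t0 = time.perf_counter()
        for _ in range(number):
            fct()
        dt = time.perf_counter() - t0
        runs.append(dt / number)
        total += dt

    return {
        "average": statistics.mean(runs),
        "deviation": statistics.pstdev(runs),
        "min_exec": min(runs),
        "max_exec": max(runs),
        "warmup_time": warmup_time,
        "ttime": total,
        "repeat": len(runs),
        "number": number,
        "context_size": 0,  # placeholder; the real helper reports its context size
    }


res = measure_time_sketch(lambda: math.cos(0.5))
assert set(res) >= {"min_exec", "max_exec", "average", "ttime", "repeat", "number"}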
6 changes: 1 addition & 5 deletions onnx_diagnostic/helpers.py
@@ -66,12 +66,8 @@ def size_type(dtype: Any) -> int:
return 4
if dtype == np.float16 or dtype == np.int16:
return 2
if dtype == np.int16:
return 2
if dtype == np.int32:
return 4
if dtype == np.int64:
return 8
if dtype == np.int8:
return 1
if hasattr(np, "uint64"):
@@ -226,7 +222,7 @@ def string_type(
)
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
return f"({tt},...)#{len(obj)}[{mini},{maxi}:A[{avg}]]"
return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]"
return f"#{len(obj)}({tt},...)"
if isinstance(obj, list):
if len(obj) < limit:
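Two small fixes above: the duplicated np.int16 branch in size_type is dropped (that case is already covered by the np.float16 or np.int16 test just before it), and the tuple summary in string_type now puts the #<count> prefix first, matching the new tests in test_helpers.py. For numpy dtypes the hard-coded byte sizes can be cross-checked directly:

import numpy as np

# The sizes kept in size_type agree with numpy's own itemsize.
assert np.dtype(np.float16).itemsize == 2
assert np.dtype(np.int16).itemsize == 2
assert np.dtype(np.int32).itemsize == 4
assert np.dtype(np.int64).itemsize == 8
assert np.dtype(np.int8).itemsize == 1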