Skip to content

Commit 4d93786

Browse files
authored
Implements validate_ep (#46)
* Implements validate_ep * mypy * req * union * documentation
1 parent c2c98cc commit 4d93786

File tree

8 files changed

+302
-12
lines changed

8 files changed

+302
-12
lines changed

_doc/api/export/index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ onnx_diagnostic.export
66
:caption: modules
77

88
dynamic_shapes
9+
validate
910

1011
CoupleInputsDynamicShapes
1112
+++++++++++++++++++++++++
@@ -19,6 +20,11 @@ ModelInputs
1920
.. autoclass:: onnx_diagnostic.export.ModelInputs
2021
:members:
2122

23+
validate_ep
24+
+++++++++++
25+
26+
.. autofunction:: onnx_diagnostic.export.validate_ep
27+
2228
Other functions
2329
+++++++++++++++
2430

_doc/api/export/validate.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.export.validate
3+
===============================
4+
5+
.. automodule:: onnx_diagnostic.export.validate
6+
:members:
7+
:no-undoc-members:
8+
:exclude-members: validate_ep
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
4+
from onnx_diagnostic.export import CoupleInputsDynamicShapes, validate_ep
5+
6+
7+
class TestValidate(ExtTestCase):
8+
@hide_stdout()
9+
def test_validate_args(self):
10+
class Model(torch.nn.Module):
11+
def forward(self, x, y):
12+
return x + y
13+
14+
model = Model()
15+
x = torch.randn((5, 6))
16+
y = torch.randn((1, 6))
17+
model(x, y)
18+
ds = ({0: "a", 1: "b"}, {1: "b"})
19+
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
20+
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
21+
validate_ep(
22+
ep,
23+
model,
24+
args=(x, y),
25+
verbose=2,
26+
copy=True,
27+
dynamic_shapes=ds,
28+
values_to_try={"a": [5, 10], "b": [10, 20]},
29+
)
30+
31+
@hide_stdout()
32+
def test_validate_kwargs(self):
33+
class Model(torch.nn.Module):
34+
def forward(self, x, y):
35+
return x + y
36+
37+
model = Model()
38+
x = torch.randn((5, 6))
39+
y = torch.randn((1, 6))
40+
model(x=x, y=y)
41+
ds = dict(x={0: "a", 1: "b"}, y={1: "b"})
42+
cpl = CoupleInputsDynamicShapes((), dict(x=x, y=y), ds)
43+
ep = torch.export.export(
44+
model, (), kwargs=dict(x=x, y=y), dynamic_shapes=cpl.replace_string_by()
45+
)
46+
validate_ep(
47+
ep,
48+
model,
49+
kwargs=dict(x=x, y=y),
50+
verbose=2,
51+
copy=True,
52+
dynamic_shapes=ds,
53+
values_to_try={"a": [5, 10], "b": [10, 20]},
54+
)
55+
56+
@hide_stdout()
57+
def test_validate_args_kwargs(self):
58+
class Model(torch.nn.Module):
59+
def forward(self, x, y):
60+
return x + y
61+
62+
model = Model()
63+
x = torch.randn((5, 6))
64+
y = torch.randn((1, 6))
65+
model(x, y=y)
66+
ds = dict(x={0: "a", 1: "b"}, y={1: "b"})
67+
cpl = CoupleInputsDynamicShapes((x,), dict(y=y), ds, args_names=["x"])
68+
ep = torch.export.export(
69+
model, (x,), kwargs=dict(y=y), dynamic_shapes=cpl.replace_string_by()
70+
)
71+
validate_ep(
72+
ep,
73+
model,
74+
args=(x,),
75+
kwargs=dict(y=y),
76+
verbose=2,
77+
copy=True,
78+
dynamic_shapes=ds,
79+
values_to_try={"a": [5, 10], "b": [10, 20]},
80+
)
81+
82+
83+
if __name__ == "__main__":
84+
unittest.main(verbosity=2)

onnx_diagnostic/export/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .dynamic_shapes import CoupleInputsDynamicShapes, ModelInputs
2+
from .validate import validate_ep

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _valid_shapes_tensor(cls, inputs, ds):
147147
issues[i] = f"d=[{d}]"
148148
return issues if issues else None
149149

150-
def _generic_walker(self, processor: Callable):
150+
def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
151151
"""
152152
Generic deserializator walking through inputs and dynamic_shapes all along.
153153
The function returns a result with the same structure as the dynamic shapes.
@@ -157,14 +157,16 @@ def _generic_walker(self, processor: Callable):
157157
f"Type mismatch, args={string_type(self.args)} and "
158158
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
159159
)
160-
return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
160+
res = self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
161+
return (tuple(), res) if args_kwargs else res
161162

162163
if not self.kwargs:
163164
assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
164165
f"Type mismatch, args={string_type(self.args)} and "
165166
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
166167
)
167-
return self._generic_walker_step(processor, self.args, self.dynamic_shapes)
168+
res = self._generic_walker_step(processor, self.args, self.dynamic_shapes)
169+
return (res, {}) if args_kwargs else res
168170

169171
assert isinstance(self.dynamic_shapes, dict), (
170172
f"Both positional and named arguments (args and kwargs) are filled. "
@@ -192,7 +194,17 @@ def _generic_walker(self, processor: Callable):
192194
)
193195
kwargs = dict(zip(self.args_names, self.args))
194196
kwargs.update(self.kwargs)
195-
return self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
197+
res = self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
198+
if args_kwargs:
199+
pgs = [None for _ in range(len(self.args))]
200+
kws = {}
201+
for k, v in res.items():
202+
if k not in self.kwargs:
203+
pgs[self.args_names.index(k)] = v
204+
else:
205+
kws[k] = v
206+
return pgs, kws
207+
return res
196208

197209
raise NotImplementedError(
198210
f"Not yet implemented when args is filled, "
@@ -285,14 +297,14 @@ def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
285297
tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
286298
)
287299
mind = min(d0, d1)
288-
indices = [slice(None) for _ in range(rank)]
300+
indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
289301
indices[i] = slice(0, mind)
290302
ind = tuple(indices)
291303
new_tensor[ind] = tensor[ind]
292304
if d1 > mind:
293305
for k in range(d1 - mind):
294-
indices0 = [slice(None) for _ in range(rank)]
295-
indices1 = [slice(None) for _ in range(rank)]
306+
indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
307+
indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
296308
indices1[i] = mind + k
297309
indices0[i] = k % mind
298310
new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
@@ -310,7 +322,9 @@ def __call__(self, inputs, ds):
310322
new_shape = self._build_new_shape(inputs.shape, ds)
311323
return self._build_new_tensor(inputs, new_shape)
312324

313-
def change_dynamic_dimensions(self, desired_values: Optional[Dict[str, int]] = None):
325+
def change_dynamic_dimensions(
326+
self, desired_values: Optional[Dict[str, int]] = None, args_kwargs: bool = False
327+
):
314328
"""
315329
A model exported with dynamic shapes is not necessarily dynamic
316330
just because the user specified dynamic shapes. The algorithm
@@ -321,6 +335,7 @@ def change_dynamic_dimensions(self, desired_values: Optional[Dict[str, int]] = N
321335
the model.
322336
323337
:param desired_values: to fixed named dimension to have the desired value
338+
:param args_kwargs: return both args, kwargs even if empty
324339
:return: new inputs
325340
326341
Example:
@@ -343,7 +358,9 @@ def change_dynamic_dimensions(self, desired_values: Optional[Dict[str, int]] = N
343358
print("before:", string_type(kwargs, with_shape=True))
344359
print("-after:", string_type(new_kwargs, with_shape=True))
345360
"""
346-
return self._generic_walker(self.ChangeDimensionProcessor(desired_values))
361+
return self._generic_walker(
362+
self.ChangeDimensionProcessor(desired_values), args_kwargs=args_kwargs
363+
)
347364

348365

349366
class ModelInputs:

onnx_diagnostic/export/validate.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import inspect
2+
import itertools
3+
import time
4+
from typing import Any, Dict, List, Optional, Tuple, Union
5+
import torch
6+
from ..helpers import string_type, max_diff, string_diff
7+
from ..helpers.torch_test_helper import torch_deepcopy
8+
from .dynamic_shapes import CoupleInputsDynamicShapes
9+
10+
11+
def compare_modules(
12+
modep: torch.nn.Module,
13+
mod: Optional[torch.nn.Module] = None,
14+
args: Optional[Tuple[Any, ...]] = None,
15+
kwargs: Optional[Dict[str, Any]] = None,
16+
copy: bool = False,
17+
exc: bool = True,
18+
verbose: int = 0,
19+
atol: float = 1e-2,
20+
rtol: float = 1e-1,
21+
) -> Dict[str, Any]:
22+
"""
23+
Compares two torch modules, usually one coming from an exported program,
24+
the other being the origin model.
25+
26+
:param model: first module
27+
:param mod: second module (it produces the expected values)
28+
:param args: positional arguments
29+
:param kwargs: named arguments
30+
:param copy: copy the inputs before executing the model (they may modify them inplace)
31+
:param exc: raise exception if discrepancies are too high
32+
:param verbose: verbosity level
33+
:param atol: absolute tolerance
34+
:param rtol: relative tolerance
35+
:return: dictionary with inputs, outputs and tolerance
36+
37+
Example:
38+
39+
.. runpython::
40+
:showcode:
41+
42+
import torch
43+
from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes
44+
45+
class Model(torch.nn.Module):
46+
def forward(self, x, y):
47+
return x + y
48+
49+
model = Model()
50+
x = torch.randn((5, 6))
51+
y = torch.randn((1, 6))
52+
model(x, y) # to make it is running
53+
54+
ds = ({0: "a", 1: "b"}, {1: "b"})
55+
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
56+
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
57+
validate_ep(
58+
ep,
59+
model,
60+
args=(x, y),
61+
verbose=2,
62+
copy=True,
63+
dynamic_shapes=ds,
64+
values_to_try={"a": [5, 10], "b": [10, 20]},
65+
)
66+
67+
"""
68+
args = args or ()
69+
kwargs = kwargs or {}
70+
71+
def _get(a):
72+
return torch_deepcopy(a) if copy else a
73+
74+
if verbose:
75+
begin = time.perf_counter()
76+
print(
77+
f"[compare_modules] check ep with "
78+
f"args={string_type(args, with_shape=True)}, "
79+
f"kwargs={string_type(kwargs, with_shape=True)}..."
80+
)
81+
got = modep(*_get(args), **_get(kwargs))
82+
if verbose:
83+
d = time.perf_counter() - begin
84+
print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
85+
if mod:
86+
if verbose:
87+
begin = time.perf_counter()
88+
print("[compare_modules] run torch module...")
89+
expected = mod(*_get(args), **_get(kwargs))
90+
diff = max_diff(expected, got)
91+
if verbose:
92+
d = time.perf_counter() - begin
93+
print(
94+
f"[compare_modules] done in {d} with "
95+
f"output={string_type(expected, with_shape=True)}"
96+
)
97+
print(f"[compare_modules] discrepancies={string_diff(diff)}")
98+
assert not exc or (
99+
diff["abs"] <= atol and diff["rel"] <= rtol
100+
), f"Discrepancies={string_diff(diff)} higher than expected."
101+
return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff)
102+
return dict(args=args, kwargs=kwargs, got=got)
103+
104+
105+
def validate_ep(
106+
ep: Union[torch.nn.Module, torch.export.ExportedProgram],
107+
mod: Optional[torch.nn.Module] = None,
108+
args: Optional[Tuple[Any, ...]] = None,
109+
kwargs: Optional[Dict[str, Any]] = None,
110+
copy: bool = False,
111+
dynamic_shapes: Optional[Any] = None,
112+
values_to_try: Optional[Dict[str, List[int]]] = None,
113+
exc: bool = True,
114+
verbose: int = 0,
115+
atol: float = 1e-2,
116+
rtol: float = 1e-1,
117+
) -> List[Dict[str, Any]]:
118+
"""
119+
Validates an exported program.
120+
121+
:param model: first module
122+
:param mod: second module (it produces the expected values)
123+
:param args: positional arguments
124+
:param kwargs: named arguments
125+
:param copy: copy the inputs before executing the model (they may modify them inplace)
126+
:param dynamic_shapes: dynamic shapes, string should be used not ``torch.export.Dim``
127+
:param values_to_try: dictionary with the values to try for every dynamic dimension
128+
:param exc: raise exception if discrepancies are too high
129+
:param verbose: verbosity level
130+
:param atol: absolute tolerance
131+
:param rtol: relative tolerance
132+
:return: dictionary with inputs, outputs and tolerance
133+
"""
134+
modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep
135+
136+
results = [
137+
compare_modules(
138+
modep, mod, args, kwargs, copy=copy, verbose=verbose, atol=atol, rtol=rtol
139+
)
140+
]
141+
142+
assert (dynamic_shapes and values_to_try) or (
143+
not dynamic_shapes and not values_to_try
144+
), "Either both dynamic_shapes and values_to_try are specified, either none."
145+
if not dynamic_shapes or not values_to_try:
146+
return results
147+
148+
items = list(values_to_try.items())
149+
keys = [_[0] for _ in items]
150+
values = [_[1] for _ in items]
151+
all_vals = list(itertools.product(*values))
152+
cpl = CoupleInputsDynamicShapes(
153+
args or (),
154+
kwargs or {},
155+
dynamic_shapes,
156+
args_names=(
157+
list(inspect.signature(modep.forward).parameters) if args and kwargs else None
158+
),
159+
)
160+
for i, vals in enumerate(all_vals):
161+
change_dims = dict(zip(keys, vals))
162+
if verbose:
163+
print(f"[validate_ep] try {i}/{len(all_vals)}: {change_dims}")
164+
new_params = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True)
165+
na, nkw = new_params
166+
c = compare_modules(
167+
modep, mod, na, nkw, copy=copy, verbose=max(verbose - 1, 0), atol=atol, rtol=rtol
168+
)
169+
results.append(c)
170+
return results

0 commit comments

Comments
 (0)