Skip to content

Commit 600a02f

Browse files
authored
Supports error distribution in max_diff (#92)
* refactors max_diff * updtae string_diff * add model_statistics * add mb * ut
1 parent 5dd7775 commit 600a02f

File tree

6 files changed

+271
-125
lines changed

6 files changed

+271
-125
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`92`: support errors distribution in max_diff
8+
* :pr:`91`: enable strings in ``guess_dynamic_shapes``
79
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
810
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
911

_unittests/ut_helpers/test_helper.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_print_pretty_onnx(self):
127127
)
128128
self.print_onnx(proto)
129129
self.print_model(proto)
130-
self.dump_onnx("test_print_pretty_onnx", proto)
130+
self.dump_onnx("test_print_pretty.onnx", proto)
131131
self.check_ort(proto)
132132
self.assertNotEmpty(proto)
133133
self.assertEmpty(None)
@@ -203,6 +203,101 @@ def test_max_diff_verbose(self):
203203
d = string_diff(diff)
204204
self.assertIsInstance(d, str)
205205

206+
def test_max_diff_hist_array(self):
207+
x = np.arange(12).reshape((3, 4)).astype(dtype=np.float32)
208+
y = x.copy()
209+
y[0, 1] += 0.1
210+
y[0, 2] += 0.01
211+
y[0, 3] += 0.001
212+
y[1, 1] += 0.0001
213+
y[1, 2] += 1
214+
y[2, 2] += 10
215+
y[1, 3] += 100
216+
y[2, 1] += 1000
217+
diff = max_diff(x, y, hist=True)
218+
self.assertEqual(
219+
diff["rep"],
220+
{
221+
">0.0": 8,
222+
">0.0001": 8,
223+
">0.001": 6,
224+
">0.01": 5,
225+
">0.1": 5,
226+
">1.0": 3,
227+
">10.0": 2,
228+
">100.0": 1,
229+
},
230+
)
231+
232+
def test_max_diff_hist_array_string_diff(self):
233+
x = np.arange(12).reshape((3, 4)).astype(dtype=np.float32)
234+
y = x.copy()
235+
y[0, 1] += 0.1
236+
y[0, 2] += 0.01
237+
y[0, 3] += 0.001
238+
y[1, 1] += 0.0001
239+
y[1, 2] += 1
240+
y[2, 2] += 10
241+
y[1, 3] += 100
242+
y[2, 1] += 1000
243+
diff = max_diff(x, y, hist=True)
244+
s = string_diff(diff)
245+
self.assertEndsWith(
246+
"/#8>0.0-#8>0.0001-#6>0.001-#5>0.01-#5>0.1-#3>1.0-#2>10.0-#1>100.0", s
247+
)
248+
249+
def test_max_diff_hist_tensor(self):
250+
x = torch.arange(12).reshape((3, 4)).to(dtype=torch.float32)
251+
y = x.clone()
252+
y[0, 1] += 0.1
253+
y[0, 2] += 0.01
254+
y[0, 3] += 0.001
255+
y[1, 1] += 0.0001
256+
y[1, 2] += 1
257+
y[2, 2] += 10
258+
y[1, 3] += 100
259+
y[2, 1] += 1000
260+
diff = max_diff(x, y, hist=True)
261+
self.assertEqual(
262+
diff["rep"],
263+
{
264+
">0.0": 8,
265+
">0.0001": 8,
266+
">0.001": 6,
267+
">0.01": 5,
268+
">0.1": 5,
269+
">1.0": 3,
270+
">10.0": 2,
271+
">100.0": 1,
272+
},
273+
)
274+
275+
def test_max_diff_hist_tensor_composed(self):
276+
x = torch.arange(12).reshape((3, 4)).to(dtype=torch.float32)
277+
y = x.clone()
278+
y[0, 1] += 0.1
279+
y[0, 2] += 0.01
280+
y[0, 3] += 0.001
281+
y[1, 1] += 0.0001
282+
y[1, 2] += 1
283+
y[2, 2] += 10
284+
y[1, 3] += 100
285+
y[2, 1] += 1000
286+
diff = max_diff([x, (x, {"e": x})], [y, (y, {"e": y})], hist=True)
287+
self.assertEqual(
288+
diff["rep"],
289+
{
290+
">0.0": 24,
291+
">0.0001": 24,
292+
">0.001": 18,
293+
">0.01": 15,
294+
">0.1": 15,
295+
">1.0": 9,
296+
">10.0": 6,
297+
">100.0": 3,
298+
},
299+
)
300+
206301
def test_type_info(self):
207302
for tt in [
208303
onnx.TensorProto.FLOAT,

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
dummy_llm,
1010
to_numpy,
1111
is_torchdynamo_exporting,
12+
model_statistics,
1213
steal_forward,
1314
replace_string_by_dynamic,
1415
to_any,
@@ -172,14 +173,15 @@ def forward(self, x, y):
172173
else:
173174
print("output", k, v)
174175
print(string_type(restored, with_shape=True))
176+
l1, l2 = 151, 160
175177
self.assertEqual(
176178
[
177-
("-Model-159", 0, "I"),
178-
("-Model-159", 0, "O"),
179-
("s1-SubModel-150", 0, "I"),
180-
("s1-SubModel-150", 0, "O"),
181-
("s2-SubModel-150", 0, "I"),
182-
("s2-SubModel-150", 0, "O"),
179+
(f"-Model-{l2}", 0, "I"),
180+
(f"-Model-{l2}", 0, "O"),
181+
(f"s1-SubModel-{l1}", 0, "I"),
182+
(f"s1-SubModel-{l1}", 0, "O"),
183+
(f"s2-SubModel-{l1}", 0, "I"),
184+
(f"s2-SubModel-{l1}", 0, "O"),
183185
],
184186
sorted(restored),
185187
)
@@ -279,6 +281,32 @@ def test_torch_deepcopy_sliding_windon_cache(self):
279281
def test_torch_deepcopy_none(self):
280282
self.assertEmpty(torch_deepcopy(None))
281283

284+
def test_model_statistics(self):
285+
class Model(torch.nn.Module):
286+
def __init__(self):
287+
super().__init__()
288+
self.p1 = torch.nn.Parameter(torch.tensor([1], dtype=torch.float32))
289+
self.b1 = torch.nn.Buffer(torch.tensor([1], dtype=torch.float32))
290+
291+
def forward(self, x, y=None):
292+
return x + y + self.p1 + self.b1
293+
294+
model = Model()
295+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
296+
model(x, y)
297+
stat = model_statistics(model)
298+
self.assertEqual(
299+
{
300+
"type": "Model",
301+
"n_modules": 1,
302+
"param_size": 4,
303+
"buffer_size": 4,
304+
"float32": 8,
305+
"size_mb": 0,
306+
},
307+
stat,
308+
)
309+
282310

283311
if __name__ == "__main__":
284312
unittest.main(verbosity=2)

onnx_diagnostic/ext_test_case.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,11 @@ def assertStartsWith(self, prefix: str, full: str):
10701070
if not full.startswith(prefix):
10711071
raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")
10721072

1073+
def assertEndsWith(self, suffix: str, full: str):
1074+
"""In the name"""
1075+
if not full.endswith(suffix):
1076+
raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
1077+
10731078
def capture(self, fct: Callable):
10741079
"""
10751080
Runs a function and capture standard output and error.

0 commit comments

Comments
 (0)