Skip to content

Commit 69a42c4

Browse files
committed
refactors max_diff
1 parent 5dd7775 commit 69a42c4

File tree

3 files changed

+166
-114
lines changed

3 files changed

+166
-114
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`91`: enable strings in ``guess_dynamic_shapes``
78
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
89
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
910

_unittests/ut_helpers/test_helper.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,84 @@ def test_max_diff_verbose(self):
203203
d = string_diff(diff)
204204
self.assertIsInstance(d, str)
205205

206+
def test_max_diff_hist_array(self):
207+
x = np.arange(12).reshape((3, 4)).astype(dtype=np.float32)
208+
y = x.copy()
209+
y[0, 1] += 0.1
210+
y[0, 2] += 0.01
211+
y[0, 3] += 0.001
212+
y[1, 1] += 0.0001
213+
y[1, 2] += 1
214+
y[2, 2] += 10
215+
y[1, 3] += 100
216+
y[2, 1] += 1000
217+
diff = max_diff(x, y, hist=True)
218+
self.assertEqual(
219+
diff["rep"],
220+
{
221+
">0.0": 8,
222+
">0.0001": 8,
223+
">0.001": 6,
224+
">0.01": 5,
225+
">0.1": 5,
226+
">1.0": 3,
227+
">10.0": 2,
228+
">100.0": 1,
229+
},
230+
)
231+
232+
def test_max_diff_hist_tensor(self):
233+
x = torch.arange(12).reshape((3, 4)).to(dtype=torch.float32)
234+
y = x.clone()
235+
y[0, 1] += 0.1
236+
y[0, 2] += 0.01
237+
y[0, 3] += 0.001
238+
y[1, 1] += 0.0001
239+
y[1, 2] += 1
240+
y[2, 2] += 10
241+
y[1, 3] += 100
242+
y[2, 1] += 1000
243+
diff = max_diff(x, y, hist=True)
244+
self.assertEqual(
245+
diff["rep"],
246+
{
247+
">0.0": 8,
248+
">0.0001": 8,
249+
">0.001": 6,
250+
">0.01": 5,
251+
">0.1": 5,
252+
">1.0": 3,
253+
">10.0": 2,
254+
">100.0": 1,
255+
},
256+
)
257+
258+
def test_max_diff_hist_tensor_composed(self):
259+
x = torch.arange(12).reshape((3, 4)).to(dtype=torch.float32)
260+
y = x.clone()
261+
y[0, 1] += 0.1
262+
y[0, 2] += 0.01
263+
y[0, 3] += 0.001
264+
y[1, 1] += 0.0001
265+
y[1, 2] += 1
266+
y[2, 2] += 10
267+
y[1, 3] += 100
268+
y[2, 1] += 1000
269+
diff = max_diff([x, (x, {"e": x})], [y, (y, {"e": y})], hist=True)
270+
self.assertEqual(
271+
diff["rep"],
272+
{
273+
">0.0": 24,
274+
">0.0001": 24,
275+
">0.001": 18,
276+
">0.01": 15,
277+
">0.1": 15,
278+
">1.0": 9,
279+
">10.0": 6,
280+
">100.0": 3,
281+
},
282+
)
283+
206284
def test_type_info(self):
207285
for tt in [
208286
onnx.TensorProto.FLOAT,

0 commit comments

Comments
 (0)