Skip to content

Commit 82ed143

Browse files
committed
fix
1 parent 1ab2efb commit 82ed143

File tree

2 files changed

+43
-17
lines changed

2 files changed

+43
-17
lines changed

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,9 @@ def study_discrepancies(
10491049
fig, ax = plt.subplots(3, 2, figsize=figsize)
10501050
vmin, vmax = d1.min().item(), d1.max().item()
10511051
ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
1052-
ax[0, 0].set_title(f"Color plot of the first tensor in\n[{vmin}, {vmax}]")
1052+
ax[0, 0].set_title(
1053+
f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
1054+
)
10531055

10541056
diff = d2 - d1
10551057
vmin, vmax = diff.min().item(), diff.max().item()

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -269,18 +269,24 @@ def get_replay_code(self) -> str:
269269
print(f"-- expected={string_type(expected, **skws)}")
270270
print(f"-- mapping={mapping}")
271271
272-
model = onnx.load("model.onnx")
272+
print()
273273
print("-- model.onnx")
274+
print()
275+
276+
model = onnx.load("model.onnx")
274277
print(pretty_onnx(model))
275-
print("--")
276278
277-
print("-- range of inputs")
279+
print()
280+
print("-- range of inputs --")
281+
print()
282+
278283
for k, v in onnx_inputs.items():
279284
print(f"-- {k}: {string_type(v, **skws, with_min_max=True)}")
280-
print("-- done.")
281-
print("--")
282285
283-
print("-- discrepancies of inputs")
286+
print()
287+
print("-- discrepancies of inputs --")
288+
print()
289+
284290
ep_feeds = {}
285291
for k, v in onnx_inputs.items():
286292
tk = mapping.get(k, k)
@@ -291,24 +297,29 @@ def get_replay_code(self) -> str:
291297
f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} "
292298
f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}"
293299
)
294-
print("-- done.")
295-
print("--")
296300
297-
print("-- run with onnx_inputs")
301+
print()
302+
print("-- SVD --")
303+
print()
304+
305+
for k, v in onnx_inputs.items():
306+
if len(v.shape) == 2:
307+
U, S, Vt = torch.linalg.svd(v.to(torch.float32))
308+
print(f" -- {k}: {S[:5]}")
309+
310+
print()
311+
print("-- run with onnx_inputs --")
312+
print()
313+
298314
sess = OnnxruntimeEvaluator(model, whole=True)
299315
feeds = onnx_inputs
300316
obtained = sess.run(None, feeds)
301317
print(f"-- obtained={string_type(obtained, **skws)}")
302318
diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
303319
print(f"-- diff: {string_diff(diff)}")
304-
print("--")
305-
print("-- run with torch_inputs")
306-
obtained = sess.run(None, ep_feeds)
307-
print(f"-- obtained={string_type(obtained, **skws)}")
308-
diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
309-
print(f"-- diff: {string_diff(diff)}")
320+
print()
321+
print("-- plots --")
310322
311-
print("-- plots")
312323
for i in range(len(expected)):
313324
study_discrepancies(
314325
expected[i],
@@ -317,6 +328,19 @@ def get_replay_code(self) -> str:
317328
name=f"disc{i}.png",
318329
bins=50,
319330
)
331+
332+
print()
333+
print("-- run with torch_inputs --")
334+
print()
335+
336+
obtained = sess.run(None, ep_feeds)
337+
print(f"-- obtained={string_type(obtained, **skws)}")
338+
diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
339+
print(f"-- diff: {string_diff(diff)}")
340+
341+
print()
342+
print("-- end --")
343+
print()
320344
"""
321345
)
322346

0 commit comments

Comments
 (0)