@@ -269,18 +269,24 @@ def get_replay_code(self) -> str:
269269 print(f"-- expected={string_type(expected, **skws)}")
270270 print(f"-- mapping={mapping}")
271271
272- model = onnx.load("model.onnx" )
272+ print( )
273273 print("-- model.onnx")
274+ print()
275+
276+ model = onnx.load("model.onnx")
274277 print(pretty_onnx(model))
275- print("--")
276278
277- print("-- range of inputs")
279+ print()
280+ print("-- range of inputs --")
281+ print()
282+
278283 for k, v in onnx_inputs.items():
279284 print(f"-- {k}: {string_type(v, **skws, with_min_max=True)}")
280- print("-- done.")
281- print("--")
282285
283- print("-- discrepancies of inputs")
286+ print()
287+ print("-- discrepancies of inputs --")
288+ print()
289+
284290 ep_feeds = {}
285291 for k, v in onnx_inputs.items():
286292 tk = mapping.get(k, k)
@@ -291,24 +297,29 @@ def get_replay_code(self) -> str:
291297 f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} "
292298 f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}"
293299 )
294- print("-- done.")
295- print("--")
296300
297- print("-- run with onnx_inputs")
301+ print()
302+ print("-- SVD --")
303+ print()
304+
305+ for k, v in onnx_inputs.items():
306+ if len(v.shape) == 2:
307+ U, S, Vt = torch.linalg.svd(v.to(torch.float32))
308+ print(f" -- {k}: {S[:5]}")
309+
310+ print()
311+ print("-- run with onnx_inputs --")
312+ print()
313+
298314 sess = OnnxruntimeEvaluator(model, whole=True)
299315 feeds = onnx_inputs
300316 obtained = sess.run(None, feeds)
301317 print(f"-- obtained={string_type(obtained, **skws)}")
302318 diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
303319 print(f"-- diff: {string_diff(diff)}")
304- print("--")
305- print("-- run with torch_inputs")
306- obtained = sess.run(None, ep_feeds)
307- print(f"-- obtained={string_type(obtained, **skws)}")
308- diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
309- print(f"-- diff: {string_diff(diff)}")
320+ print()
321+ print("-- plots --")
310322
311- print("-- plots")
312323 for i in range(len(expected)):
313324 study_discrepancies(
314325 expected[i],
@@ -317,6 +328,19 @@ def get_replay_code(self) -> str:
317328 name=f"disc{i}.png",
318329 bins=50,
319330 )
331+
332+ print()
333+ print("-- run with torch_inputs --")
334+ print()
335+
336+ obtained = sess.run(None, ep_feeds)
337+ print(f"-- obtained={string_type(obtained, **skws)}")
338+ diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
339+ print(f"-- diff: {string_diff(diff)}")
340+
341+ print()
342+ print("-- end --")
343+ print()
320344 """
321345 )
322346
0 commit comments