Skip to content

Commit aceb306

Browse files
committed
study
1 parent 82ed143 commit aceb306

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
3030
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended
31-
from onnx_diagnostic.helpers.torch_helper import to_tensor
31+
from onnx_diagnostic.helpers.torch_helper import to_tensor, study_discrepancies
3232

3333
TFLOAT = onnx.TensorProto.FLOAT
3434

@@ -425,6 +425,12 @@ def test_get_weight_type(self):
425425
dt = get_weight_type(model)
426426
self.assertEqual(torch.float32, dt)
427427

428+
def test_study_discrepancies(self):
429+
t1 = torch.rand((3, 4))
430+
t2 = torch.rand((3, 4))
431+
ax = study_discrepancies(t1, t2)
432+
self.assertEqual(ax.shape, ((3, 2)))
433+
428434

429435
if __name__ == "__main__":
430436
unittest.main(verbosity=2)

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ def study_discrepancies(
10241024
figsize: Optional[Tuple[int, int]] = (15, 15),
10251025
title: Optional[str] = None,
10261026
name: Optional[str] = None,
1027-
) -> "Axes": # noqa: F821
1027+
) -> "matplotlib.axes.Axes": # noqa: F821
10281028
"""
10291029
Computes different metrics for the discrepancies.
10301030
Returns graphs.

0 commit comments

Comments
 (0)