Skip to content

Commit fb81896

Browse files
committed
fix import
1 parent 11434ab commit fb81896

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
check_model_ort,
1616
iterator_initializer_constant,
1717
from_array_extended,
18+
tensor_statistics,
1819
)
1920

2021

2122
TFLOAT = TensorProto.FLOAT
2223

2324

24-
class TestOnnxTools(ExtTestCase):
25+
class TestOnnxHelper(ExtTestCase):
2526

2627
def _get_model(self):
2728
model = oh.make_model(
@@ -242,6 +243,11 @@ def test_iterate_function(self):
242243
self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64))
243244
self.assertIsInstance(li[0][1], torch.Tensor)
244245

246+
def test_statistics(self):
247+
rnd = np.random.rand(40, 50).astype(np.float16)
248+
stat = tensor_statistics(rnd)
249+
print(stat)
250+
245251

246252
if __name__ == "__main__":
247253
unittest.main(verbosity=2)

onnx_diagnostic/helpers/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def string_type(
427427

428428
# Tensors
429429
if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
430-
from .onnx_helper import torch_dtype_to_onnx_dtype
430+
from .torch_helper import torch_dtype_to_onnx_dtype
431431

432432
i = torch_dtype_to_onnx_dtype(obj.dtype)
433433
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,77 @@ def iterator_initializer_constant(
818818
yield from iterator_initializer_constant(
819819
att.g, use_numpy=use_numpy, prefix=f"{prefix}{name}"
820820
)
821+
822+
823+
def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union[float, str]]:
824+
"""
825+
Produces statistics on a tensor.
826+
827+
:param tensor: tensor
828+
:return: statistics
829+
830+
.. runpython::
831+
:showcode:
832+
833+
import pprint
834+
import numpy as np
835+
from onnx_diagnostic.helper.onnx_helper import tensor_statistics
836+
837+
t = np.random.rand(40, 50).astype(np.float16)
838+
pprint.pprint(tensor_statistics(t))
839+
"""
840+
from .helper import size_type
841+
842+
if isinstance(tensor, TensorProto):
843+
tensor = to_array_extended(tensor)
844+
stat = dict(
845+
mean=float(tensor.mean()),
846+
std=float(tensor.std()),
847+
shape="x".join(map(str, tensor.shape)),
848+
numel=tensor.size,
849+
size=tensor.size * size_type(tensor.dtype),
850+
itype=np_dtype_to_tensor_dtype(tensor.dtype),
851+
stype=onnx_dtype_name(np_dtype_to_tensor_dtype(tensor.dtype)),
852+
min=float(tensor.min()),
853+
max=float(tensor.max()),
854+
nnan=np.isnan(tensor).sum(),
855+
)
856+
857+
hist = np.array(
858+
[
859+
0,
860+
1e-10,
861+
1e-8,
862+
1e-7,
863+
1e-6,
864+
1e-5,
865+
0.0001,
866+
0.001,
867+
0.01,
868+
0.1,
869+
0.5,
870+
1,
871+
1.96,
872+
10,
873+
100,
874+
1e3,
875+
1e4,
876+
1e5,
877+
1e6,
878+
1e7,
879+
1e8,
880+
1e10,
881+
1e50,
882+
],
883+
dtype=tensor.dtype,
884+
)
885+
hist = np.array(sorted(set(hist[~np.isinf(hist)])), dtype=tensor.dtype)
886+
ind = np.digitize(np.abs(tensor).reshape((-1,)), hist, right=True)
887+
cou = np.bincount(ind, minlength=ind.shape[0] + 1)
888+
stat.update(
889+
dict(zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))]))
890+
)
891+
ii = (np.arange(9) + 1) / 10
892+
qu = np.quantile(tensor, ii)
893+
stat.update({f"q{i}": float(q) for i, q in zip(ii, qu)})
894+
return stat

0 commit comments

Comments
 (0)