Skip to content

Commit 73fba13

Browse files
committed
Add to_tensor
1 parent e151311 commit 73fba13

File tree

19 files changed

+286
-192
lines changed

19 files changed

+286
-192
lines changed

_doc/api/helpers/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ onnx_diagnostic.helpers
1616
onnx_helper
1717
ort_session
1818
rt_helper
19-
torch_test_helper
19+
torch_helper
2020

2121
.. autofunction:: onnx_diagnostic.helpers.flatten_object
2222

_doc/api/helpers/torch_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.torch_helper
3+
====================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.torch_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/helpers/torch_test_helper.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

_doc/examples/plot_export_tiny_llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import transformers
3232
from onnx_diagnostic import doc
3333
from onnx_diagnostic.helpers import string_type
34-
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
34+
from onnx_diagnostic.helpers.torch_helper import steal_forward
3535
from onnx_diagnostic.torch_models.llms import get_tiny_llm
3636

3737

@@ -77,7 +77,7 @@ def _forward_(*args, _f=None, **kwargs):
7777
model.forward = keep_model_forward
7878

7979
# %%
80-
# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steal_forward`.
80+
# Another syntax with :func:`onnx_diagnostic.helpers.torch_helper.steal_forward`.
8181

8282
with steal_forward(model):
8383
model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)

_unittests/ut_export/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
requires_onnxscript,
99
)
1010
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
11-
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
11+
from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting
1212

1313
try:
1414
from experimental_experiment.torch_interpreter import to_onnx

_unittests/ut_helpers/test_torch_test_helper.py renamed to _unittests/ut_helpers/test_torch_helper.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import unittest
2+
import numpy as np
23
import ml_dtypes
34
import onnx
45
import torch
56
import transformers
67
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
78
from onnx_diagnostic.helpers import max_diff, string_type
8-
from onnx_diagnostic.helpers.torch_test_helper import (
9+
from onnx_diagnostic.helpers.torch_helper import (
910
dummy_llm,
1011
to_numpy,
1112
is_torchdynamo_exporting,
@@ -24,6 +25,8 @@
2425
make_sliding_window_cache,
2526
)
2627
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
28+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended
29+
from onnx_diagnostic.helpers.torch_helper import to_tensor
2730

2831
TFLOAT = onnx.TensorProto.FLOAT
2932

@@ -205,7 +208,7 @@ def forward(self, x, y):
205208
else:
206209
print("output", k, v)
207210
print(string_type(restored, with_shape=True))
208-
l1, l2 = 183, 192
211+
l1, l2 = 186, 195
209212
self.assertEqual(
210213
[
211214
(f"-Model-{l2}", 0, "I"),
@@ -344,6 +347,35 @@ def forward(self, x, y=None):
344347
stat,
345348
)
346349

350+
def test_to_tensor(self):
351+
for dtype in [
352+
np.int8,
353+
np.uint8,
354+
np.int16,
355+
np.uint16,
356+
np.int32,
357+
np.uint32,
358+
np.int64,
359+
np.uint64,
360+
np.float16,
361+
np.float32,
362+
np.float64,
363+
]:
364+
with self.subTest(dtype=dtype):
365+
a = np.random.rand(4, 5).astype(dtype)
366+
proto = from_array_extended(a)
367+
b = to_array_extended(proto)
368+
self.assertEqualArray(a, b)
369+
c = to_tensor(proto)
370+
self.assertEqualArray(a, c)
371+
372+
for dtype in [torch.bfloat16]:
373+
with self.subTest(dtype=dtype):
374+
a = torch.rand((4, 5), dtype=dtype)
375+
proto = from_array_extended(a)
376+
c = to_tensor(proto)
377+
self.assertEqualArray(a, c)
378+
347379

348380
if __name__ == "__main__":
349381
unittest.main(verbosity=2)

_unittests/ut_tasks/try_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
33
from onnx_diagnostic.helpers import string_type
4-
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
4+
from onnx_diagnostic.helpers.torch_helper import steal_forward
55

66

77
class TestHuggingFaceHubModel(ExtTestCase):

_unittests/ut_torch_export_patches/test_patch_expressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
patched_selector,
88
patched_float_arange,
99
)
10-
from onnx_diagnostic.helpers.torch_test_helper import fake_torchdynamo_exporting
10+
from onnx_diagnostic.helpers.torch_helper import fake_torchdynamo_exporting
1111

1212

1313
class TestOnnxExportErrors(ExtTestCase):

_unittests/ut_torch_export_patches/test_patch_loops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
import torch
33
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
4-
from onnx_diagnostic.helpers.torch_test_helper import (
4+
from onnx_diagnostic.helpers.torch_helper import (
55
is_torchdynamo_exporting,
66
fake_torchdynamo_exporting,
77
)

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
1212
torch_export_patches,
1313
)
14-
from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy
14+
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
1515

1616

1717
class TestPatchSerialization(ExtTestCase):

0 commit comments

Comments
 (0)