Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.5.0
+++++

* :pr:`88`: extends ``steal_forward`` to dump inputs and outputs into an onnx model
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)

0.4.4
Expand Down
1 change: 1 addition & 0 deletions _doc/api/helpers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ onnx_diagnostic.helpers
config_helper
helper
memory_peak
mini_onnx_builder
onnx_helper
ort_session
rt_helper
Expand Down
7 changes: 7 additions & 0 deletions _doc/api/helpers/mini_onnx_builder.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.helpers.mini_onnx_builder
=========================================

.. automodule:: onnx_diagnostic.helpers.mini_onnx_builder
:members:
:no-undoc-members:
156 changes: 156 additions & 0 deletions _unittests/ut_helpers/test_mini_onnx_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import unittest
import numpy as np
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
from onnx_diagnostic.helpers.mini_onnx_builder import (
create_onnx_model_from_input_tensors,
create_input_tensors_from_onnx_model,
MiniOnnxBuilder,
)
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers import string_type


class TestMiniOnnxBuilder(ExtTestCase):
    """Tests for :mod:`onnx_diagnostic.helpers.mini_onnx_builder`.

    The helpers serialize nested python containers (tuples, lists, dicts)
    of numpy arrays and torch tensors into an onnx model, and restore
    them by running that model. The round-trip must preserve both the
    tensor values and the exact container types and nesting.
    """

    def test_mini_onnx_builder_sequence_onnx(self):
        # Round-trip a sequence output through the project's reference evaluator.
        builder = MiniOnnxBuilder()
        builder.append_output_sequence("name", [np.array([6, 7])])
        onx = builder.to_onnx()
        ref = ExtendedReferenceEvaluator(onx)
        # The model takes no inputs: the data is embedded in the model itself.
        got = ref.run(None, {})
        self.assertEqualAny([np.array([6, 7])], got[0])

    def test_mini_onnx_builder_sequence_ort(self):
        # Same round-trip as above but executed with onnxruntime,
        # to check the generated model is valid for a second runtime.
        from onnxruntime import InferenceSession

        builder = MiniOnnxBuilder()
        builder.append_output_sequence("name", [np.array([6, 7])])
        onx = builder.to_onnx()
        ref = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
        got = ref.run(None, {})
        self.assertEqualAny([np.array([6, 7])], got[0])

    def test_mini_onnx_builder(self):
        # Each entry exercises one container layout mixing numpy arrays and
        # torch tensors; empty containers ({} , tuple(), []) are included on
        # purpose since they are easy to lose in a serialization round-trip.
        data = [
            # tuple containing dicts (one of them empty)
            (
                np.array([1, 2], dtype=np.int64),
                torch.tensor([4, 5], dtype=torch.float32),
                {
                    "tt1": np.array([-1, -2], dtype=np.int64),
                    "tt2": torch.tensor([-4, -5], dtype=torch.float32),
                },
                {},
            ),
            # dict containing dicts (one of them empty)
            {
                "t1": np.array([1, 2], dtype=np.int64),
                "t2": torch.tensor([4, 5], dtype=torch.float32),
                "d1": {
                    "tt1": np.array([-1, -2], dtype=np.int64),
                    "tt2": torch.tensor([-4, -5], dtype=torch.float32),
                },
                "d2": {},
            },
            # tuple containing tuples (one of them empty)
            (
                np.array([1, 2], dtype=np.int64),
                torch.tensor([4, 5], dtype=torch.float32),
                (
                    np.array([-1, -2], dtype=np.int64),
                    torch.tensor([-4, -5], dtype=torch.float32),
                ),
                tuple(),
            ),
            # dict containing tuples (one of them empty)
            {
                "t1": np.array([1, 2], dtype=np.int64),
                "t2": torch.tensor([4, 5], dtype=torch.float32),
                "l1": (
                    np.array([-1, -2], dtype=np.int64),
                    torch.tensor([-4, -5], dtype=torch.float32),
                ),
                "l2": tuple(),
            },
            # nested
            (
                {
                    "t1": np.array([1, 2], dtype=np.int64),
                    "t2": torch.tensor([4, 5], dtype=torch.float32),
                    "l1": (
                        np.array([-1, -2], dtype=np.int64),
                        torch.tensor([-4, -5], dtype=torch.float32),
                    ),
                    "l2": tuple(),
                },
                (
                    np.array([1, 2], dtype=np.int64),
                    torch.tensor([4, 5], dtype=torch.float32),
                    (
                        np.array([-1, -2], dtype=np.int64),
                        torch.tensor([-4, -5], dtype=torch.float32),
                    ),
                    tuple(),
                ),
            ),
            # simple
            np.array([1, 2], dtype=np.int64),
            torch.tensor([4, 5], dtype=torch.float32),
            (np.array([1, 2], dtype=np.int64), torch.tensor([4, 5], dtype=torch.float32)),
            [np.array([1, 2], dtype=np.int64), torch.tensor([4, 5], dtype=torch.float32)],
            {
                "t1": np.array([1, 2], dtype=np.int64),
                "t2": torch.tensor([4, 5], dtype=torch.float32),
            },
            # tuple containing lists (one of them empty)
            (
                np.array([1, 2], dtype=np.int64),
                torch.tensor([4, 5], dtype=torch.float32),
                [
                    np.array([-1, -2], dtype=np.int64),
                    torch.tensor([-4, -5], dtype=torch.float32),
                ],
                [],
            ),
            # dict containing lists (one of them empty)
            {
                "t1": np.array([1, 2], dtype=np.int64),
                "t2": torch.tensor([4, 5], dtype=torch.float32),
                "l1": [
                    np.array([-1, -2], dtype=np.int64),
                    torch.tensor([-4, -5], dtype=torch.float32),
                ],
                "l2": [],
            },
        ]

        for inputs in data:
            # string_type labels the subtest with the container structure
            # so a failure identifies which layout broke.
            with self.subTest(types=string_type(inputs)):
                model = create_onnx_model_from_input_tensors(inputs)
                restored = create_input_tensors_from_onnx_model(model)
                self.assertEqualAny(inputs, restored)

    def test_mini_onnx_builder_transformers(self):
        # Round-trip a transformers DynamicCache, bare and inside a tuple.
        cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
        self.assertEqual(len(cache.key_cache), 1)
        self.assertEqual(len(cache.value_cache), 1)

        data = [(cache,), cache]

        for inputs in data:
            with self.subTest(types=string_type(inputs)):
                model = create_onnx_model_from_input_tensors(inputs)
                restored = create_input_tensors_from_onnx_model(model)
                self.assertEqualAny(inputs, restored)

    def test_mini_onnx_builder_transformers_sep(self):
        # Same as above but with a custom name separator; serialization and
        # deserialization must agree on the same ``sep`` value.
        cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
        self.assertEqual(len(cache.key_cache), 1)
        self.assertEqual(len(cache.value_cache), 1)

        data = [(cache,), cache]

        for inputs in data:
            with self.subTest(types=string_type(inputs)):
                model = create_onnx_model_from_input_tensors(inputs, sep="#")
                restored = create_input_tensors_from_onnx_model(model, sep="#")
                self.assertEqualAny(inputs, restored)


if __name__ == "__main__":
    # Run this file's tests directly with verbose, per-test output.
    unittest.main(verbosity=2)
85 changes: 85 additions & 0 deletions _unittests/ut_helpers/test_torch_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
make_mamba_cache,
make_sliding_window_cache,
)
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model

TFLOAT = onnx.TensorProto.FLOAT

Expand Down Expand Up @@ -59,6 +60,90 @@ def forward(self, x, y):
with steal_forward(model):
model(*inputs)

@hide_stdout()
def test_steal_forward_multi(self):
class SubModel(torch.nn.Module):
def forward(self, x):
return x * x

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.s1 = SubModel()
self.s2 = SubModel()

def forward(self, x, y):
return self.s1(x) + self.s2(y)

inputs = torch.rand(3, 4), torch.rand(3, 4)
model = Model()
with steal_forward(
[
(
"main",
model,
),
(" s1", model.s1),
(" s2", model.s2),
]
):
model(*inputs)

@hide_stdout()
def test_steal_forward_dump_file(self):
class SubModel(torch.nn.Module):
def forward(self, x):
return x * x

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.s1 = SubModel()
self.s2 = SubModel()

def forward(self, x, y):
return self.s1(x) + self.s2(y)

inputs = torch.rand(3, 4), torch.rand(3, 4)
model = Model()
dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
with steal_forward(
[
(
"main",
model,
),
(" s1", model.s1),
(" s2", model.s2),
],
dump_file=dump_file,
):
res1 = model(*inputs)
res2 = model(*inputs)
self.assertExists(dump_file)
restored = create_input_tensors_from_onnx_model(dump_file)
self.assertEqual(
[
("main", 0, "I"),
("main", 0, "O"),
("main", 1, "I"),
("main", 1, "O"),
("s1", 0, "I"),
("s1", 0, "O"),
("s1", 1, "I"),
("s1", 1, "O"),
("s2", 0, "I"),
("s2", 0, "O"),
("s2", 1, "I"),
("s2", 1, "O"),
],
sorted(restored),
)
self.assertEqualAny(restored["main", 0, "I"], (inputs, {}))
self.assertEqualAny(restored["main", 1, "I"], (inputs, {}))
self.assertEqualAny(restored["main", 0, "O"], res1)
self.assertEqualAny(restored["main", 0, "O"], res2)

def test_replace_string_by_dynamic(self):
example = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
Expand Down
2 changes: 0 additions & 2 deletions _unittests/ut_torch_export_patches/test_dynamic_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,6 @@ def test_phi2_export_interpreter(self):
)
# ep = ep.run_decompositions()

# from experimental_experiment.torch_interpreter.tracing import CustomTracer
# CustomTracer.remove_unnecessary_slices(ep.graph)
memorize = []

class MyInterpreter(torch.fx.Interpreter):
Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/helpers/memory_peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def start_spying_on(

.. code-block:: python

from experimental_experiment.memory_peak import start_spying_on, flatten
from onnx_diagnostic.helpers.memory_peak import start_spying_on, flatten

p = start_spying_on()
# ...
Expand Down
Loading
Loading