Skip to content

Commit 641ff8c

Browse files
authored
Changes Cache serialization (#277)
* Changes Cache serialization * mypy * fix * other fixes * fix other tests * fix modelbuilder * disable two ewemples * fix some issues * fix caches * more tests * fix version * fix issues * mypy * import * fix issues
1 parent 7979496 commit 641ff8c

36 files changed

+984
-448
lines changed

_doc/examples/plot_export_hub_codellama.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
from onnx_diagnostic import doc
2323
from onnx_diagnostic.ext_test_case import unit_test_going
2424
from onnx_diagnostic.helpers import string_type
25-
from onnx_diagnostic.torch_models.hghub import (
26-
get_untrained_model_with_inputs,
27-
)
25+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
2826
from onnx_diagnostic.torch_models.hghub.hub_api import (
2927
get_model_info,
3028
get_pretrained_config,

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333
from onnx_diagnostic.helpers.rt_helper import make_feeds
3434
from onnx_diagnostic.torch_export_patches import torch_export_patches
3535
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
36-
from onnx_diagnostic.torch_models.hghub import (
37-
get_untrained_model_with_inputs,
38-
)
36+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
3937

4038
warnings.simplefilter("ignore")
4139

_doc/technical/plot_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def simple_generate_with_cache(
155155
dtype = get_weight_type(model)
156156
print("-- model dtype:", dtype)
157157
export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
158-
exporter = "custom" if "custom" in sys.argv else "onnx-dynamo"
158+
exporter = "onnx-dynamo" if "dynamo" in sys.argv else "custom"
159159
model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
160160
if not os.path.exists(model_name):
161161
# This step is slow so let's skip it if it was already done.

_unittests/ut_export/test_api.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
4+
from onnx_diagnostic.helpers import max_diff
5+
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
6+
from onnx_diagnostic.helpers.rt_helper import make_feeds
7+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
8+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
9+
from onnx_diagnostic.torch_export_patches import torch_export_patches
410
from onnx_diagnostic.export.api import to_onnx
511

612

@@ -19,16 +25,80 @@ def forward(self, x, y):
1925
(x, y),
2026
dynamic_shapes=ds,
2127
exporter="custom",
22-
filename=self.get_dump_file("custom.onnx"),
28+
filename=self.get_dump_file("to_onnx_custom.onnx"),
2329
)
2430
to_onnx(
2531
Model(),
2632
(x, y),
2733
dynamic_shapes=ds,
2834
exporter="onnx-dynamo",
29-
filename=self.get_dump_file("onnx-dynamo.onnx"),
35+
filename=self.get_dump_file("to_onnx_onnx-dynamo.onnx"),
3036
)
3137

38+
@hide_stdout()
39+
def test_tiny_llm_to_onnx(self):
40+
import onnxruntime
41+
42+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
43+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
44+
b1 = data["inputs_batch1"]
45+
filenames = {
46+
"custom": self.get_dump_file("test_tiny_llm_to_onnx-custom.onnx"),
47+
"onnx-dynamo": self.get_dump_file("test_tiny_llm_to_onnx-dynamo.onnx"),
48+
"modelbuilder": self.get_dump_file("model.onnx"),
49+
}
50+
if not has_transformers("4.55"):
51+
# <4.55: torch._check(causal_mask.shape[3] != 33)
52+
# torch._check(causal_mask.shape[3] == 33)
53+
del filenames["onnx-dynamo"]
54+
del inputs["position_ids"]
55+
del ds["position_ids"]
56+
del b1["position_ids"]
57+
58+
expected = model(**torch_deepcopy(b1))
59+
60+
with torch_export_patches(patch_transformers=True):
61+
for exporter, filename in filenames.items():
62+
with self.subTest(exporter=exporter):
63+
to_onnx(
64+
model,
65+
kwargs=inputs,
66+
dynamic_shapes=ds,
67+
exporter=exporter,
68+
filename=filename,
69+
)
70+
for exporter, filename in filenames.items():
71+
with self.subTest(exporter=f"validate-{exporter}"):
72+
sess = onnxruntime.InferenceSession(
73+
filename, providers=["CPUExecutionProvider"]
74+
)
75+
feeds = make_feeds(sess, b1, use_numpy=True)
76+
got = sess.run(None, feeds)
77+
diff = max_diff(expected, got)
78+
assert diff["abs"] <= 1e-5, f"diff={diff}"
79+
80+
problem = dict(
81+
input_ids=torch.tensor([[24320]], dtype=torch.int64),
82+
attention_mask=torch.tensor([[1, 1, 1, 1]], dtype=torch.int64),
83+
past_key_values=make_dynamic_cache(
84+
[
85+
torch.rand((1, 1, 3, 96), dtype=torch.float32),
86+
torch.rand((1, 1, 3, 96), dtype=torch.float32),
87+
]
88+
),
89+
)
90+
91+
expected = model(**torch_deepcopy(problem))
92+
for exporter, filename in filenames.items():
93+
with self.subTest(exporter=f"full-mask-{exporter}"):
94+
sess = onnxruntime.InferenceSession(
95+
filename, providers=["CPUExecutionProvider"]
96+
)
97+
feeds = make_feeds(sess, problem, use_numpy=True)
98+
got = sess.run(None, feeds)
99+
diff = max_diff(expected, got)
100+
assert diff["abs"] <= 1e-5, f"diff={diff}"
101+
32102

33103
if __name__ == "__main__":
34104
unittest.main(verbosity=2)

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 99 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -452,19 +452,18 @@ def forward(self, cache, z):
452452
(
453453
(
454454
[
455-
[{}, {}],
456-
[
457-
{
458-
0: torch.export.Dim.DYNAMIC,
459-
2: torch.export.Dim.DYNAMIC,
460-
3: torch.export.Dim.DYNAMIC,
461-
},
462-
{
463-
0: torch.export.Dim.DYNAMIC,
464-
2: torch.export.Dim.DYNAMIC,
465-
3: torch.export.Dim.DYNAMIC,
466-
},
467-
],
455+
{},
456+
{
457+
0: torch.export.Dim.DYNAMIC,
458+
2: torch.export.Dim.DYNAMIC,
459+
3: torch.export.Dim.DYNAMIC,
460+
},
461+
{},
462+
{
463+
0: torch.export.Dim.DYNAMIC,
464+
2: torch.export.Dim.DYNAMIC,
465+
3: torch.export.Dim.DYNAMIC,
466+
},
468467
],
469468
{3: torch.export.Dim.DYNAMIC},
470469
),
@@ -520,11 +519,10 @@ def forward(self, cache, z):
520519
(
521520
(
522521
[
523-
[{}, {}],
524-
[
525-
{0: "dim_0I_1o_0l0", 2: "dim_0I_1o_0l2", 3: "dim_0I_1o_0l3"},
526-
{0: "dim_0I_1o_1l0", 2: "dim_0I_1o_1l2", 3: "dim_0I_1o_1l3"},
527-
],
522+
{},
523+
{0: "dim_0I_1o0", 2: "dim_0I_1o2", 3: "dim_0I_1o3"},
524+
{},
525+
{0: "dim_0I_3o0", 2: "dim_0I_3o2", 3: "dim_0I_3o3"},
528526
],
529527
{3: "dim_1I3"},
530528
),
@@ -641,18 +639,18 @@ def test_couple_input_ds_cache(self):
641639
kwargs,
642640
{
643641
"A": ds_batch,
644-
"B": (ds_batch, [[ds_batch, ds_batch], [ds_batch, ds_batch]]),
642+
"B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch]),
645643
},
646644
).invalid_dimensions_for_export(),
647645
)
648646
self.assertEqual(
649-
{"B": (None, [[None, {2: "d=[1]"}], [None, {2: "d=[1]"}]])},
647+
{"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])},
650648
Cls(
651649
(),
652650
kwargs,
653651
{
654652
"A": ds_batch,
655-
"B": (ds_batch, [[ds_batch, ds_batch_seq], [ds_batch, ds_batch_seq]]),
653+
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
656654
},
657655
).invalid_dimensions_for_export(),
658656
)
@@ -831,18 +829,17 @@ def test_dynamic_cache_replace_by_string(self):
831829

832830
DYN = torch.export.Dim.DYNAMIC
833831
ds = {
834-
"cache": [
835-
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
836-
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
837-
]
832+
"cache": [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}]
838833
}
839834
inst = CoupleInputsDynamicShapes((), dict(cache=cache), ds)
840835
as_string = inst.replace_by_string()
841836
self.assertEqual(
842837
{
843838
"cache": [
844-
[{0: "Dim0", 1: "Dim1"}, {0: "Dim2", 1: "Dim3"}],
845-
[{0: "Dim4", 1: "Dim5"}, {0: "Dim6", 1: "Dim7"}],
839+
{0: "Dim0", 1: "Dim1"},
840+
{0: "Dim2", 1: "Dim3"},
841+
{0: "Dim4", 1: "Dim5"},
842+
{0: "Dim6", 1: "Dim7"},
846843
]
847844
},
848845
as_string,
@@ -865,6 +862,81 @@ def test_unbatch_inputs(self):
865862
s,
866863
)
867864

865+
def test_guess_dynamic_cache_without_patches(self):
866+
n_layers = 2
867+
bsize, nheads, slen, dim = 2, 4, 3, 7
868+
cache = make_dynamic_cache(
869+
[
870+
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
871+
for i in range(n_layers)
872+
]
873+
)
874+
z = torch.randn((1, 1, 1, 7))
875+
cache2 = make_dynamic_cache(
876+
[
877+
(
878+
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
879+
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
880+
)
881+
for i in range(n_layers)
882+
]
883+
)
884+
inputs = [
885+
(cache, z),
886+
(cache2, torch.randn((1, 1, 1, 8))),
887+
]
888+
889+
class Model(torch.nn.Module):
890+
def forward(self, cache, z):
891+
cache = CacheKeyValue(cache)
892+
return (
893+
z
894+
+ cache.key_cache[0]
895+
+ cache.key_cache[1]
896+
+ cache.value_cache[0]
897+
+ cache.value_cache[1]
898+
)
899+
900+
mi = ModelInputs(Model(), inputs)
901+
ds = mi.guess_dynamic_shapes()
902+
DYN = torch.export.Dim.DYNAMIC
903+
self.assertEqual(
904+
(
905+
(
906+
[
907+
{0: DYN, 2: DYN, 3: DYN},
908+
{0: DYN, 2: DYN, 3: DYN},
909+
{0: DYN, 2: DYN, 3: DYN},
910+
{0: DYN, 2: DYN, 3: DYN},
911+
],
912+
{3: DYN},
913+
),
914+
{},
915+
),
916+
ds,
917+
)
918+
919+
def test_invalid_dimensions_for_export(self):
920+
ags = []
921+
kws = dict(
922+
input_ids=torch.randint(0, 10, (2, 3)),
923+
attention_mask=torch.randint(0, 1, (2, 33)),
924+
position_ids=torch.randint(0, 10, (2, 3)),
925+
past_key_values=make_dynamic_cache(
926+
[torch.rand((2, 1, 30, 96)), torch.rand((2, 1, 30, 96))]
927+
),
928+
)
929+
ds = dict(
930+
input_ids={0: "batch", 1: "seq_length"},
931+
attention_mask={0: "batch", 1: "seq_length"},
932+
position_ids={0: "batch", 1: "seq_length"},
933+
past_key_values=[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
934+
)
935+
with torch_export_patches(patch_transformers=True):
936+
cpl = CoupleInputsDynamicShapes(ags, kws, ds)
937+
backed_size_oblivious = cpl.invalid_dimensions_for_export()
938+
self.assertFalse(backed_size_oblivious)
939+
868940

869941
if __name__ == "__main__":
870942
unittest.main(verbosity=2)

_unittests/ut_export/test_serialization.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def forward(self, cache):
3030
cache = self._get_cache()
3131
DYN = torch.export.Dim.DYNAMIC
3232
ds = {0: DYN, 1: DYN, 3: DYN}
33-
dynamic_shapes = ([[ds, ds], [ds, ds]],)
33+
dynamic_shapes = ([ds, ds, ds, ds],)
3434
with torch_export_patches(patch_transformers=True):
3535
exp = torch.export.export(Model(), (cache,), dynamic_shapes=dynamic_shapes)
3636
self.assertNotEmpty(exp)
@@ -44,7 +44,7 @@ def forward(self, cache):
4444
cache = self._get_cache()
4545
flat_unflat = flatten_unflatten_for_dynamic_shapes(cache)
4646
s = string_type(flat_unflat, with_shape=True)
47-
self.assertEqual("#2[#2[T1s2x4x1x7,T1s2x4x1x7],#2[T1s2x4x1x7,T1s2x4x1x7]]", s)
47+
self.assertEqual("#4[T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7]", s)
4848

4949
def test_dynamic_cache_bypass(self):
5050
class Model(torch.nn.Module):
@@ -55,7 +55,7 @@ def forward(self, cache):
5555
with torch_export_patches(patch_transformers=True):
5656
flat_unflat = flatten_unflatten_for_dynamic_shapes(cache)
5757
s = string_type(flat_unflat, with_shape=True)
58-
self.assertEqual("#2[#2[T1s2x4x1x7,T1s2x4x1x7],#2[T1s2x4x1x7,T1s2x4x1x7]]", s)
58+
self.assertEqual("#4[T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7]", s)
5959

6060
def test_dynamic_cache_guess_static(self):
6161
class Model(torch.nn.Module):
@@ -65,7 +65,7 @@ def forward(self, cache):
6565
cache = self._get_cache()
6666
md = ModelInputs(Model(), [(cache,)])
6767
guessed = md.guess_dynamic_shapes()
68-
self.assertEqual(guessed, (([[{}, {}], [{}, {}]],), {}))
68+
self.assertEqual(guessed, (([{}, {}, {}, {}],), {}))
6969

7070
def test_dynamic_cache_guess_auto(self):
7171
class Model(torch.nn.Module):
@@ -77,7 +77,7 @@ def forward(self, cache):
7777
guessed = md.guess_dynamic_shapes(auto=True)
7878
AUTO = torch.export.Dim.AUTO
7979
ds = {i: AUTO for i in range(4)} # noqa: C420
80-
self.assertEqual(guessed, (([[ds, ds], [ds, ds]],), {}))
80+
self.assertEqual(guessed, (([ds, ds, ds, ds],), {}))
8181

8282
def test_dynamic_cache_guess_dynamic(self):
8383
class Model(torch.nn.Module):
@@ -88,18 +88,11 @@ def forward(self, cache):
8888
Model(), [(self._get_cache(),), (self._get_cache(bsize=3, nheads=5),)]
8989
)
9090
guessed = md.guess_dynamic_shapes()
91+
print("****", guessed)
9192
DYN = torch.export.Dim.DYNAMIC
9293
self.assertEqual(
94+
(([{0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}],), {}),
9395
guessed,
94-
(
95-
(
96-
[
97-
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
98-
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
99-
],
100-
),
101-
{},
102-
),
10396
)
10497

10598

0 commit comments

Comments
 (0)