Skip to content

Commit 7f708d9

Browse files
committed
add a test to check patch
1 parent dbfd255 commit 7f708d9

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
4+
from onnx_diagnostic.helpers import string_type
5+
from onnx_diagnostic.cache_helpers import make_dynamic_cache
6+
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
7+
bypass_export_some_errors,
8+
)
9+
10+
11+
class TestOnnxExportErrors(ExtTestCase):
12+
@ignore_warnings(UserWarning)
13+
@hide_stdout()
14+
def test_export_dynamic_cache_update(self):
15+
for strict in self.subloop([True, False], verbose=1):
16+
17+
class SubModelCache(torch.nn.Module):
18+
def forward(self, cache):
19+
d = cache.__class__()
20+
d.update(cache.key_cache[0] + 1, cache.value_cache[0] + 2, 0)
21+
d.update(cache.key_cache[0] + 3, cache.value_cache[0] + 5, 1)
22+
return d
23+
24+
class SubModel(torch.nn.Module):
25+
def forward(self, x, cache):
26+
return x + cache.key_cache[0] + cache.value_cache[0]
27+
28+
class Model(torch.nn.Module):
29+
def __init__(self):
30+
super().__init__()
31+
self.sub = SubModel()
32+
self.subcache = SubModelCache()
33+
34+
def forward(self, x, cache):
35+
return self.sub(x, self.subcache(cache))
36+
37+
# no patch
38+
cache = make_dynamic_cache(
39+
[(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)]
40+
)
41+
model = Model()
42+
inputs = (torch.randn((5, 6, 5, 6)), cache)
43+
expected = model(*inputs)
44+
45+
DYN = torch.export.Dim.DYNAMIC
46+
ep = torch.export.export(
47+
model,
48+
inputs,
49+
dynamic_shapes=({0: DYN, 2: DYN}, [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]]),
50+
strict=strict,
51+
)
52+
mod = ep.module()
53+
got = mod(*inputs)
54+
self.assertEqualArray(expected, got)
55+
56+
# patching
57+
with bypass_export_some_errors(patch_transformers=True):
58+
got = model(*inputs)
59+
self.assertEqualArray(expected, got)
60+
ep2 = torch.export.export(
61+
model,
62+
inputs,
63+
dynamic_shapes=(
64+
{0: DYN, 2: DYN},
65+
[[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]],
66+
),
67+
strict=strict,
68+
)
69+
mod = ep2.module()
70+
got = mod(*inputs)
71+
self.assertEqualArray(expected, got)
72+
73+
74+
if __name__ == "__main__":
75+
unittest.main(verbosity=2)

onnx_diagnostic/ext_test_case.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import glob
7+
import itertools
78
import logging
89
import os
910
import re
@@ -1094,3 +1095,18 @@ def assert_onnx_disc(
10941095
def _debug(self):
10951096
"Tells if DEBUG=1 is set up."
10961097
return os.environ.get("DEBUG") in BOOLEAN_VALUES
1098+
1099+
def subloop(self, *args, verbose: int = 0):
1100+
"Loops over elements and calls :meth:`unittests.TextCase.subTest`."
1101+
if len(args) == 1:
1102+
for it in args[0]:
1103+
with self.subTest(case=it):
1104+
if verbose:
1105+
print(f"[subloop] it={it!r}")
1106+
yield it
1107+
else:
1108+
for it in itertools.product(*args):
1109+
with self.subTest(case=it):
1110+
if verbose:
1111+
print(f"[subloop] it={it!r}")
1112+
yield it

0 commit comments

Comments
 (0)