Skip to content

Commit b9def44

Browse files
committed
add test for phi2
1 parent 32ef23e commit b9def44

File tree

2 files changed

+95
-4
lines changed

2 files changed

+95
-4
lines changed

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import copy
12
import unittest
23
import torch
34
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
5+
from onnx_diagnostic.helpers import string_type
46
from onnx_diagnostic.cache_helpers import make_dynamic_cache
57
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
68
bypass_export_some_errors,
79
)
10+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
811

912

1013
class TestOnnxExportErrors(ExtTestCase):
@@ -69,6 +72,97 @@ def forward(self, x, cache):
6972
got = mod(*inputs)
7073
self.assertEqualArray(expected, got)
7174

75+
@ignore_warnings(UserWarning)
76+
def test_phi2_export_module(self):
77+
data = get_untrained_model_with_inputs("microsoft/phi-2")
78+
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
79+
str_inputs = string_type(inputs, with_shape=True, with_min_max=True)
80+
inputs_copied = copy.deepcopy(inputs)
81+
expected = model(**inputs_copied)
82+
self.maxDiff = None
83+
self.assertEqual(str_inputs, string_type(inputs, with_shape=True, with_min_max=True))
84+
85+
# The cache is modified inplace, that's why, we copied it.
86+
self.assertNotEqual(
87+
string_type(inputs, with_shape=True, with_min_max=True),
88+
string_type(inputs_copied, with_shape=True, with_min_max=True),
89+
)
90+
inputs_copied = copy.deepcopy(inputs)
91+
self.assertEqual(
92+
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
93+
)
94+
95+
with bypass_export_some_errors(patch_transformers=True):
96+
ep = torch.export.export(
97+
model,
98+
(),
99+
kwargs=inputs,
100+
dynamic_shapes=dyn_shapes,
101+
strict=False, # True works but then the it fails during the execution
102+
)
103+
mod = ep.module()
104+
inputs_copied = copy.deepcopy(inputs)
105+
self.assertEqual(
106+
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
107+
)
108+
got = mod(**inputs_copied)
109+
self.assertEqualAny(expected, got)
110+
111+
inputs_copied = copy.deepcopy(inputs)
112+
self.assertEqual(
113+
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
114+
)
115+
mod = ep.module()
116+
got = mod(**inputs_copied)
117+
self.assertEqualAny(expected, got)
118+
119+
@ignore_warnings(UserWarning)
120+
def test_phi2_export_interpreter(self):
121+
data = get_untrained_model_with_inputs("microsoft/phi-2")
122+
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
123+
str_inputs = string_type(inputs, with_shape=True, with_min_max=True)
124+
inputs_copied = copy.deepcopy(inputs)
125+
expected = model(**inputs_copied)
126+
self.maxDiff = None
127+
self.assertEqual(str_inputs, string_type(inputs, with_shape=True, with_min_max=True))
128+
129+
# The cache is modified inplace, that's why, we copied it.
130+
self.assertNotEqual(
131+
string_type(inputs, with_shape=True, with_min_max=True),
132+
string_type(inputs_copied, with_shape=True, with_min_max=True),
133+
)
134+
inputs_copied = copy.deepcopy(inputs)
135+
self.assertEqual(
136+
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
137+
)
138+
139+
with bypass_export_some_errors(patch_transformers=True):
140+
ep = torch.export.export(
141+
model,
142+
(),
143+
kwargs=inputs,
144+
dynamic_shapes=dyn_shapes,
145+
strict=False, # True works but then the it fails during the execution
146+
)
147+
148+
# from experimental_experiment.torch_interpreter.tracing import CustomTracer
149+
# CustomTracer.remove_unnecessary_slices(ep.graph)
150+
memorize = []
151+
152+
class MyInterpreter(torch.fx.Interpreter):
153+
def call_function(self, target, args, kwargs):
154+
res = super().call_function(target, args, kwargs)
155+
memorize.append((target, args, kwargs, res))
156+
return res
157+
158+
inputs_copied = copy.deepcopy(inputs)
159+
self.assertEqual(
160+
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
161+
)
162+
args, _spec = torch.utils._pytree.tree_flatten(inputs_copied)
163+
got = MyInterpreter(ep.module()).run(*args)
164+
self.assertEqualAny(expected, got)
165+
72166

73167
if __name__ == "__main__":
74168
unittest.main(verbosity=2)

onnx_diagnostic/helpers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -477,10 +477,7 @@ def string_type(
477477
if isinstance(obj, torch.nn.Module):
478478
return f"{obj.__class__.__name__}(...)"
479479

480-
if isinstance(obj, torch.dtype):
481-
return f"{obj.__class__.__name__}({obj})"
482-
483-
if isinstance(obj, torch.memory_format):
480+
if isinstance(obj, (torch.device, torch.dtype, torch.memory_format, torch.layout)):
484481
return f"{obj.__class__.__name__}({obj})"
485482

486483
if isinstance(obj, torch.utils._pytree.TreeSpec):

0 commit comments

Comments
 (0)