Skip to content

Commit c850d43

Browse files
committed
add patches
1 parent 6831594 commit c850d43

File tree

10 files changed

+1174
-0
lines changed

10 files changed

+1174
-0
lines changed

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ API of onnx_diagnostic
77
:maxdepth: 1
88
:caption: submodules
99

10+
torch_export_patches/index
1011

1112
.. toctree::
1213
:maxdepth: 1
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
onnx_diagnostic.torch_export_patches
2+
====================================
3+
4+
.. automodule:: onnx_diagnostic.torch_export_patches
5+
:members:
6+
:no-undoc-members:
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import unittest
2+
from experimental_experiment.ext_test_case import (
3+
ExtTestCase,
4+
requires_torch,
5+
requires_transformers,
6+
skipif_ci_windows,
7+
ignore_warnings,
8+
)
9+
from onnx_diagnostic.helpers import string_type
10+
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
11+
bypass_export_some_errors,
12+
)
13+
14+
15+
class TestOnnxExportErrors(ExtTestCase):
    """Tests flattening and exporting a ``MambaCache`` under the export patches.

    Each test performs the operation under test inside
    ``bypass_export_some_errors()``.
    """

    @staticmethod
    def _mamba_config(num_hidden_layers):
        """Build the minimal configuration stub passed to ``MambaCache``.

        The stub only defines the five attributes the tests set; every value
        except ``num_hidden_layers`` is the same for all tests.

        :param num_hidden_layers: value stored in ``num_hidden_layers``
        :return: an object exposing ``intermediate_size``, ``state_size``,
            ``conv_kernel``, ``num_hidden_layers`` and ``dtype``
        """
        # torch is imported lazily, as in the tests themselves.
        import torch

        class _config:
            def __init__(self):
                self.intermediate_size = 8
                self.state_size = 16
                self.conv_kernel = 32
                self.num_hidden_layers = num_hidden_layers
                self.dtype = torch.float16

        return _config()

    @requires_transformers("4.49.999")
    @skipif_ci_windows("not working on Windows")
    @ignore_warnings(UserWarning)
    def test_pytree_flatten_mamba_cache(self):
        """Flatten then unflatten a ``MambaCache`` and compare both caches."""
        import torch.utils._pytree as py_pytree
        from transformers.cache_utils import MambaCache

        cache = MambaCache(self._mamba_config(64), max_batch_size=1, device="cpu")

        with bypass_export_some_errors():
            values, spec = py_pytree.tree_flatten(cache)
            cache2 = py_pytree.tree_unflatten(values, spec)
            # The rebuilt cache must expose the same metadata and tensors.
            self.assertEqual(cache.dtype, cache2.dtype)
            self.assertEqual(cache.max_batch_size, cache2.max_batch_size)
            self.assertEqual(cache.intermediate_size, cache2.intermediate_size)
            self.assertEqual(cache.ssm_state_size, cache2.ssm_state_size)
            self.assertEqual(cache.conv_kernel_size, cache2.conv_kernel_size)
            self.assertEqualArrayAny(cache.conv_states, cache2.conv_states)
            self.assertEqualArrayAny(cache.ssm_states, cache2.ssm_states)

    @requires_transformers("4.43")
    @requires_torch("2.7")
    @skipif_ci_windows("not working on Windows")
    @ignore_warnings(UserWarning)
    def test_exportable_mamba_cache(self):
        """Export a module taking a ``MambaCache`` input with static shapes."""
        import torch
        from transformers.models.mamba.modeling_mamba import MambaCache

        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor, cache: MambaCache):
                x1 = cache.ssm_states[0] + x
                x2 = cache.conv_states[0][:, :, ::2] + x1
                return x2

        cache = MambaCache(self._mamba_config(64), max_batch_size=1, device="cpu")
        self.assertEqual(
            string_type(cache), "MambaCache(conv_states=[T10r3,...], ssm_states=[T10r3,...])"
        )
        x = torch.ones(2, 8, 16).to(torch.float16)
        # Eager run first: the module must work before any export is attempted.
        model = Model()
        model(x, cache)

        with bypass_export_some_errors():
            # A fresh cache is created inside the patched context.
            cache = MambaCache(self._mamba_config(64), max_batch_size=1, device="cpu")
            torch.export.export(Model(), (x, cache))

    @requires_transformers("4.49.999")
    @skipif_ci_windows("not working on Windows")
    @ignore_warnings(UserWarning)
    def test_exportable_mamba_cache_dynamic(self):
        """Export a module taking a ``MambaCache`` input with dynamic shapes."""
        import torch
        from transformers.models.mamba.modeling_mamba import MambaCache

        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor, cache: MambaCache):
                x1 = cache.ssm_states[0] + x
                x2 = cache.conv_states[0][:, :, ::2] + x1
                return x2

        cache = MambaCache(self._mamba_config(2), max_batch_size=1, device="cpu")
        self.assertEqual(
            string_type(cache),
            "MambaCache(conv_states=#2[T10r3,T10r3], ssm_states=#2[T10r3,T10r3])",
        )
        x = torch.ones(2, 8, 16).to(torch.float16)
        # Eager run first: the module must work before any export is attempted.
        model = Model()
        model(x, cache)
        DYN = torch.export.Dim.DYNAMIC

        with bypass_export_some_errors():
            # A fresh cache is created inside the patched context.
            cache = MambaCache(self._mamba_config(2), max_batch_size=1, device="cpu")
            # Dimension 0 is dynamic for ``x`` and for every cache entry; the
            # two inner lists presumably map to the two conv states and the
            # two ssm states (num_hidden_layers == 2) -- TODO confirm order.
            torch.export.export(
                Model(),
                (x, cache),
                dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]),
            )
117+
118+
119+
# Run this module's tests with verbose output when it is executed as a script.
if __name__ == "__main__":
120+
unittest.main(verbosity=2)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .onnx_export_errors import (
2+
bypass_export_some_errors,
3+
register_additional_serialization_functions,
4+
)

0 commit comments

Comments
 (0)