Skip to content

Commit 7d65218

Browse files
committed
custom cache
1 parent 0e8155c commit 7d65218

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,49 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
121121
dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]),
122122
)
123123

124+
@ignore_warnings(UserWarning)
125+
def test_exportable_dynamic_shapes_constraints(self):
126+
import torch
127+
128+
class CustomCache:
129+
def __init__(self, shape=None):
130+
self.cache = [torch.zeros((shape)), torch.zeros((shape))] if shape else []
131+
132+
def flatten_cache(cache):
133+
return [cache.cache], ["cache"]
134+
135+
def unflatten_cache(values, context, output_type=None):
136+
cache = CustomCache()
137+
cache.cache = values[0]
138+
return cache
139+
140+
def flatten_with_keys_cache(d):
141+
values, context = flatten_cache(d)
142+
return [
143+
(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
144+
], context
145+
146+
torch.utils._pytree.register_pytree_node(
147+
CustomCache,
148+
flatten_cache,
149+
unflatten_cache,
150+
serialized_type_name=f"{CustomCache.__module__}.{CustomCache.__name__}",
151+
flatten_with_keys_fn=flatten_with_keys_cache,
152+
)
153+
154+
class Model(torch.nn.Module):
155+
def forward(self, x, cache):
156+
return cache.cache[0][0, :] + x
157+
158+
model = Model()
159+
model.eval()
160+
x, cache = torch.rand((2, 4)), CustomCache((2, 4))
161+
model(x, cache)
162+
DYN = torch.export.Dim.DYNAMIC
163+
torch.export.export(
164+
model, (x, cache), dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}]])
165+
)
166+
124167

125168
if __name__ == "__main__":
126169
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)