@@ -121,6 +121,54 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
             dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]),
         )
 
+    @ignore_warnings(UserWarning)
+    def test_exportable_dynamic_shapes_constraints(self):
+        import torch
+
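+        # A minimal cache holder with no built-in pytree support in torch.export.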
+        class CustomCache:
+            def __init__(self, shape=None):
+                self.cache = [torch.zeros(shape), torch.zeros(shape)] if shape else []
+
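+        # Pytree hooks: flatten returns the children and a context, unflatten rebuilds.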
+        def flatten_cache(cache):
+            return [cache.cache], ["cache"]
+
+        def unflatten_cache(values, context, output_type=None):
+            cache = CustomCache()
+            cache.cache = values[0]
+            return cache
+
+        def flatten_with_keys_cache(d):
+            values, context = flatten_cache(d)
+            return [
+                (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+            ], context
+
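+        # Register CustomCache as a pytree node so torch.export can flatten it.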
+        torch.utils._pytree.register_pytree_node(
+            CustomCache,
+            flatten_cache,
+            unflatten_cache,
+            serialized_type_name=f"{CustomCache.__module__}.{CustomCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_cache,
+        )
+
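+        # The model reads the first cached tensor, so export must trace into the cache.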
+        class Model(torch.nn.Module):
+            def forward(self, x, cache):
+                return cache.cache[0][0, :] + x
+
+        model = Model()
+        model.eval()
+        x, cache = torch.rand((2, 4)), CustomCache((2, 4))
+        model(x, cache)
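+        # dynamic_shapes mirrors the flattened pytree: one spec per cached tensor.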
+        DYN = torch.export.Dim.DYNAMIC
+        torch.export.export(
+            model, (x, cache), dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}]])
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)