@@ -862,6 +862,60 @@ def test_unbatch_inputs(self):
862862 s ,
863863 )
864864
865+ def test_guess_dynamic_cache_without_patches (self ):
866+ n_layers = 2
867+ bsize , nheads , slen , dim = 2 , 4 , 3 , 7
868+ cache = make_dynamic_cache (
869+ [
870+ (torch .randn (bsize , nheads , slen , dim ), torch .randn (bsize , nheads , slen , dim ))
871+ for i in range (n_layers )
872+ ]
873+ )
874+ z = torch .randn ((1 , 1 , 1 , 7 ))
875+ cache2 = make_dynamic_cache (
876+ [
877+ (
878+ torch .randn (bsize + 1 , nheads , slen + 1 , dim + 1 ),
879+ torch .randn (bsize + 1 , nheads , slen + 1 , dim + 1 ),
880+ )
881+ for i in range (n_layers )
882+ ]
883+ )
884+ inputs = [
885+ (cache , z ),
886+ (cache2 , torch .randn ((1 , 1 , 1 , 8 ))),
887+ ]
888+
889+ class Model (torch .nn .Module ):
890+ def forward (self , cache , z ):
891+ cache = CacheKeyValue (cache )
892+ return (
893+ z
894+ + cache .key_cache [0 ]
895+ + cache .key_cache [1 ]
896+ + cache .value_cache [0 ]
897+ + cache .value_cache [1 ]
898+ )
899+
900+ mi = ModelInputs (Model (), inputs )
901+ ds = mi .guess_dynamic_shapes ()
902+ DYN = torch .export .Dim .DYNAMIC
903+ self .assertEqual (
904+ (
905+ (
906+ [
907+ {0 : DYN , 2 : DYN , 3 : DYN },
908+ {0 : DYN , 2 : DYN , 3 : DYN },
909+ {0 : DYN , 2 : DYN , 3 : DYN },
910+ {0 : DYN , 2 : DYN , 3 : DYN },
911+ ],
912+ {3 : DYN },
913+ ),
914+ {},
915+ ),
916+ ds ,
917+ )
918+
865919
866920if __name__ == "__main__" :
867921 unittest .main (verbosity = 2 )
0 commit comments