@@ -452,19 +452,18 @@ def forward(self, cache, z):
452452 (
453453 (
454454 [
455- [{}, {}],
456- [
457- {
458- 0 : torch .export .Dim .DYNAMIC ,
459- 2 : torch .export .Dim .DYNAMIC ,
460- 3 : torch .export .Dim .DYNAMIC ,
461- },
462- {
463- 0 : torch .export .Dim .DYNAMIC ,
464- 2 : torch .export .Dim .DYNAMIC ,
465- 3 : torch .export .Dim .DYNAMIC ,
466- },
467- ],
455+ {},
456+ {
457+ 0 : torch .export .Dim .DYNAMIC ,
458+ 2 : torch .export .Dim .DYNAMIC ,
459+ 3 : torch .export .Dim .DYNAMIC ,
460+ },
461+ {},
462+ {
463+ 0 : torch .export .Dim .DYNAMIC ,
464+ 2 : torch .export .Dim .DYNAMIC ,
465+ 3 : torch .export .Dim .DYNAMIC ,
466+ },
468467 ],
469468 {3 : torch .export .Dim .DYNAMIC },
470469 ),
@@ -520,11 +519,10 @@ def forward(self, cache, z):
520519 (
521520 (
522521 [
523- [{}, {}],
524- [
525- {0 : "dim_0I_1o_0l0" , 2 : "dim_0I_1o_0l2" , 3 : "dim_0I_1o_0l3" },
526- {0 : "dim_0I_1o_1l0" , 2 : "dim_0I_1o_1l2" , 3 : "dim_0I_1o_1l3" },
527- ],
522+ {},
523+ {0 : "dim_0I_1o0" , 2 : "dim_0I_1o2" , 3 : "dim_0I_1o3" },
524+ {},
525+ {0 : "dim_0I_3o0" , 2 : "dim_0I_3o2" , 3 : "dim_0I_3o3" },
528526 ],
529527 {3 : "dim_1I3" },
530528 ),
@@ -641,18 +639,18 @@ def test_couple_input_ds_cache(self):
641639 kwargs ,
642640 {
643641 "A" : ds_batch ,
644- "B" : (ds_batch , [[ ds_batch , ds_batch ], [ ds_batch , ds_batch ] ]),
642+ "B" : (ds_batch , [ds_batch , ds_batch , ds_batch , ds_batch ]),
645643 },
646644 ).invalid_dimensions_for_export (),
647645 )
648646 self .assertEqual (
649- {"B" : (None , [[ None , {2 : "d=[1]" }], [ None , {2 : "d=[1]" }] ])},
647+ {"B" : (None , [None , {2 : "d=[1]" }, None , {2 : "d=[1]" }])},
650648 Cls (
651649 (),
652650 kwargs ,
653651 {
654652 "A" : ds_batch ,
655- "B" : (ds_batch , [[ ds_batch , ds_batch_seq ], [ ds_batch , ds_batch_seq ] ]),
653+ "B" : (ds_batch , [ds_batch , ds_batch_seq , ds_batch , ds_batch_seq ]),
656654 },
657655 ).invalid_dimensions_for_export (),
658656 )
@@ -831,18 +829,17 @@ def test_dynamic_cache_replace_by_string(self):
831829
832830 DYN = torch .export .Dim .DYNAMIC
833831 ds = {
834- "cache" : [
835- [{0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN }],
836- [{0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN }],
837- ]
832+ "cache" : [{0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN }]
838833 }
839834 inst = CoupleInputsDynamicShapes ((), dict (cache = cache ), ds )
840835 as_string = inst .replace_by_string ()
841836 self .assertEqual (
842837 {
843838 "cache" : [
844- [{0 : "Dim0" , 1 : "Dim1" }, {0 : "Dim2" , 1 : "Dim3" }],
845- [{0 : "Dim4" , 1 : "Dim5" }, {0 : "Dim6" , 1 : "Dim7" }],
839+ {0 : "Dim0" , 1 : "Dim1" },
840+ {0 : "Dim2" , 1 : "Dim3" },
841+ {0 : "Dim4" , 1 : "Dim5" },
842+ {0 : "Dim6" , 1 : "Dim7" },
846843 ]
847844 },
848845 as_string ,
@@ -865,6 +862,81 @@ def test_unbatch_inputs(self):
865862 s ,
866863 )
867864
865+ def test_guess_dynamic_cache_without_patches (self ):
866+ n_layers = 2
867+ bsize , nheads , slen , dim = 2 , 4 , 3 , 7
868+ cache = make_dynamic_cache (
869+ [
870+ (torch .randn (bsize , nheads , slen , dim ), torch .randn (bsize , nheads , slen , dim ))
871+ for i in range (n_layers )
872+ ]
873+ )
874+ z = torch .randn ((1 , 1 , 1 , 7 ))
875+ cache2 = make_dynamic_cache (
876+ [
877+ (
878+ torch .randn (bsize + 1 , nheads , slen + 1 , dim + 1 ),
879+ torch .randn (bsize + 1 , nheads , slen + 1 , dim + 1 ),
880+ )
881+ for i in range (n_layers )
882+ ]
883+ )
884+ inputs = [
885+ (cache , z ),
886+ (cache2 , torch .randn ((1 , 1 , 1 , 8 ))),
887+ ]
888+
889+ class Model (torch .nn .Module ):
890+ def forward (self , cache , z ):
891+ cache = CacheKeyValue (cache )
892+ return (
893+ z
894+ + cache .key_cache [0 ]
895+ + cache .key_cache [1 ]
896+ + cache .value_cache [0 ]
897+ + cache .value_cache [1 ]
898+ )
899+
900+ mi = ModelInputs (Model (), inputs )
901+ ds = mi .guess_dynamic_shapes ()
902+ DYN = torch .export .Dim .DYNAMIC
903+ self .assertEqual (
904+ (
905+ (
906+ [
907+ {0 : DYN , 2 : DYN , 3 : DYN },
908+ {0 : DYN , 2 : DYN , 3 : DYN },
909+ {0 : DYN , 2 : DYN , 3 : DYN },
910+ {0 : DYN , 2 : DYN , 3 : DYN },
911+ ],
912+ {3 : DYN },
913+ ),
914+ {},
915+ ),
916+ ds ,
917+ )
918+
919+ def test_invalid_dimensions_for_export (self ):
920+ ags = []
921+ kws = dict (
922+ input_ids = torch .randint (0 , 10 , (2 , 3 )),
923+ attention_mask = torch .randint (0 , 1 , (2 , 33 )),
924+ position_ids = torch .randint (0 , 10 , (2 , 3 )),
925+ past_key_values = make_dynamic_cache (
926+ [torch .rand ((2 , 1 , 30 , 96 )), torch .rand ((2 , 1 , 30 , 96 ))]
927+ ),
928+ )
929+ ds = dict (
930+ input_ids = {0 : "batch" , 1 : "seq_length" },
931+ attention_mask = {0 : "batch" , 1 : "seq_length" },
932+ position_ids = {0 : "batch" , 1 : "seq_length" },
933+ past_key_values = [{0 : "batch" , 2 : "cache_length" }, {0 : "batch" , 2 : "cache_length" }],
934+ )
935+ with torch_export_patches (patch_transformers = True ):
936+ cpl = CoupleInputsDynamicShapes (ags , kws , ds )
937+ backed_size_oblivious = cpl .invalid_dimensions_for_export ()
938+ self .assertFalse (backed_size_oblivious )
939+
868940
869941if __name__ == "__main__" :
870942 unittest .main (verbosity = 2 )
0 commit comments