@@ -22,7 +22,7 @@ def test_get_phi2(self):
2222 @ignore_warnings (UserWarning )
2323 @requires_transformers ("4.54" )
2424 @requires_torch ("2.9.99" )
25- def test_export_phi2_1 (self ):
25+ def test_export_phi2_1_batch_size_1 (self ):
2626 # exporting vmap does not work
2727 data = get_phi2 (num_hidden_layers = 2 )
2828 model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
@@ -35,6 +35,22 @@ def test_export_phi2_1(self):
3535 )
3636 assert ep
3737
38+ @ignore_warnings (UserWarning )
39+ @requires_transformers ("4.54" )
40+ @requires_torch ("2.9.99" )
41+ def test_export_phi2_1_batch_size_2 (self ):
42+ # exporting vmap does not work
43+ data = get_phi2 (num_hidden_layers = 2 , batch = 2 )
44+ model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
45+ self .assertEqual (
46+ {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
47+ )
48+ with torch_export_patches (patch_transformers = True ):
49+ ep = torch .export .export (
50+ model , (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
51+ )
52+ assert ep
53+
3854
3955if __name__ == "__main__" :
4056 unittest .main (verbosity = 2 )
0 commit comments