@@ -24,12 +24,15 @@ def test_get_phi2(self):
2424 @requires_torch ("2.9.99" )
2525 def test_export_phi2_1_batch_size_1 (self ):
2626 # exporting vmap does not work
27- data = get_phi2 (num_hidden_layers = 2 )
27+ data = get_phi2 (num_hidden_layers = 2 , batch_size = 1 )
2828 model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
29+ self .assertEqual (inputs ["input_ids" ].shape [0 ], 1 )
2930 self .assertEqual (
3031 {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
3132 )
32- with torch_export_patches (patch_transformers = True ):
33+ with torch .fx .experimental ._config .patch (
34+ backed_size_oblivious = True
35+ ), torch_export_patches (patch_transformers = True ):
3336 ep = torch .export .export (
3437 model , (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
3538 )
@@ -40,8 +43,9 @@ def test_export_phi2_1_batch_size_1(self):
4043 @requires_torch ("2.9.99" )
4144 def test_export_phi2_1_batch_size_2 (self ):
4245 # exporting vmap does not work
43- data = get_phi2 (num_hidden_layers = 2 , batch = 2 )
46+ data = get_phi2 (num_hidden_layers = 2 , batch_size = 2 )
4447 model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
48+ self .assertEqual (inputs ["input_ids" ].shape [0 ], 2 )
4549 self .assertEqual (
4650 {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
4751 )
0 commit comments