88)
99from onnx_diagnostic .torch_models .llms import get_phi2
1010from onnx_diagnostic .helpers import string_type
11- from onnx_diagnostic .torch_export_patches import torch_export_patches
11+ from onnx_diagnostic .torch_export_patches import (
12+ torch_export_patches ,
13+ register_additional_serialization_functions ,
14+ )
1215from onnx_diagnostic .torch_export_patches .patch_inputs import use_dyn_not_str
1316
1417
@@ -21,8 +24,8 @@ def test_get_phi2(self):
2124
2225 @ignore_warnings (UserWarning )
2326 @requires_transformers ("4.54" )
24- @requires_torch ("2.9 .99" )
25- def test_export_phi2_1_batch_size_1 (self ):
27+ @requires_torch ("2.10 .99" )
28+ def test_export_phi2_1_batch_size_1_oblivious (self ):
2629 # exporting vmap does not work
2730 data = get_phi2 (num_hidden_layers = 2 , batch_size = 1 )
2831 model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
@@ -38,6 +41,40 @@ def test_export_phi2_1_batch_size_1(self):
3841 )
3942 assert ep
4043
44+ @ignore_warnings (UserWarning )
45+ @requires_transformers ("4.54" )
46+ @requires_torch ("2.9.99" )
47+ def test_export_phi2_1_batch_size_1_not_oblivious (self ):
48+ # exporting vmap does not work
49+ data = get_phi2 (num_hidden_layers = 2 , batch_size = 1 )
50+ model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
51+ self .assertEqual (inputs ["input_ids" ].shape [0 ], 1 )
52+ self .assertEqual (
53+ {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
54+ )
55+ with torch_export_patches (patch_transformers = True ):
56+ ep = torch .export .export (
57+ model , (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
58+ )
59+ assert ep
60+
61+ @ignore_warnings (UserWarning )
62+ @requires_transformers ("4.54" )
63+ @requires_torch ("2.12" )
64+ def test_export_phi2_1_batch_size_1_no_patch (self ):
65+ # exporting vmap does not work
66+ data = get_phi2 (num_hidden_layers = 2 , batch_size = 1 )
67+ model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
68+ self .assertEqual (inputs ["input_ids" ].shape [0 ], 1 )
69+ self .assertEqual (
70+ {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
71+ )
72+ with register_additional_serialization_functions (patch_transformers = True ):
73+ ep = torch .export .export (
74+ model , (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
75+ )
76+ assert ep
77+
4178 @ignore_warnings (UserWarning )
4279 @requires_transformers ("4.54" )
4380 @requires_torch ("2.9.99" )
0 commit comments