11import unittest
22import torch
3- from onnx_diagnostic .ext_test_case import (
4- ExtTestCase ,
5- ignore_warnings ,
6- requires_transformers ,
7- requires_python ,
8- )
3+ from onnx_diagnostic .ext_test_case import ExtTestCase , ignore_warnings , requires_transformers
94from onnx_diagnostic .torch_models .llms import get_phi2
105from onnx_diagnostic .helpers import string_type
11- from onnx_diagnostic .torch_export_patches import bypass_export_some_errors
126
137
148class TestLlmPhi (ExtTestCase ):
@@ -29,23 +23,6 @@ def test_export_phi2_1(self):
2923 ep = torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds )
3024 assert ep
3125
32- @ignore_warnings (UserWarning )
33- @requires_python ((3 , 12 ))
34- def test_export_phi2_2_bypassed (self ):
35- data = get_phi2 (num_hidden_layers = 2 )
36- model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
37- self .assertEqual (
38- {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
39- )
40- with bypass_export_some_errors (patch_transformers = True ) as modificator :
41- inputs = modificator (inputs )
42- ep = torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
43- assert ep
44- with bypass_export_some_errors (patch_transformers = True ) as modificator :
45- inputs = modificator (inputs )
46- ep = torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
47- assert ep
48-
4926
5027if __name__ == "__main__" :
5128 unittest .main (verbosity = 2 )
0 commit comments