@@ -18,7 +18,10 @@ def test_onnx_export_tiny_llm(self):
1818 data = get_tiny_llm ()
1919 model , inputs = data ["model" ], data ["inputs" ]
2020 self .assertEqual ({"attention_mask" , "past_key_values" , "input_ids" }, set (inputs ))
21- with bypass_export_some_errors (patch_transformers = True , replace_dynamic_cache = True ):
21+ with bypass_export_some_errors (
22+ patch_transformers = True , replace_dynamic_cache = True
23+ ) as modificator :
24+ inputs = modificator (inputs )
2225 ep = torch .onnx .export (
2326 model ,
2427 (),
@@ -38,7 +41,10 @@ def test_onnx_export_tiny_llm_cdbg(self):
3841 data = get_tiny_llm ()
3942 model , inputs = data ["model" ], data ["inputs" ]
4043 self .assertEqual ({"attention_mask" , "past_key_values" , "input_ids" }, set (inputs ))
41- with bypass_export_some_errors (patch_transformers = True , replace_dynamic_cache = True ):
44+ with bypass_export_some_errors (
45+ patch_transformers = True , replace_dynamic_cache = True
46+ ) as modificator :
47+ inputs = modificator (inputs )
4248 onx = to_onnx (model , (), kwargs = inputs , dynamic_shapes = data ["dynamic_shapes" ])
4349 self .assert_onnx_disc (
4450 inspect .currentframe ().f_code .co_name , onx , model , inputs , verbose = 1
0 commit comments