import unittest

import torch

from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.rt_helper import make_feeds
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
611
@@ -19,16 +24,66 @@ def forward(self, x, y):
1924 (x , y ),
2025 dynamic_shapes = ds ,
2126 exporter = "custom" ,
22- filename = self .get_dump_file ("custom .onnx" ),
27+ filename = self .get_dump_file ("to_onnx_custom .onnx" ),
2328 )
2429 to_onnx (
2530 Model (),
2631 (x , y ),
2732 dynamic_shapes = ds ,
2833 exporter = "onnx-dynamo" ,
29- filename = self .get_dump_file ("onnx -dynamo.onnx" ),
34+ filename = self .get_dump_file ("to_onnx_onnx -dynamo.onnx" ),
3035 )
3136
37+ @hide_stdout ()
38+ def test_tiny_llm_to_onnx (self ):
39+ import onnxruntime
40+
41+ data = get_untrained_model_with_inputs ("arnir0/Tiny-LLM" )
42+ model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
43+ b1 = data ["inputs_batch1" ]
44+ filenames = {
45+ "custom" : self .get_dump_file ("test_tiny_llm_to_onnx-custom.onnx" ),
46+ "onnx-dynamo" : self .get_dump_file ("test_tiny_llm_to_onnx-dynamo.onnx" ),
47+ "modelbuilder" : self .get_dump_file ("model.onnx" ),
48+ }
49+ del inputs ["position_ids" ]
50+ del ds ["position_ids" ]
51+ del b1 ["position_ids" ]
52+
53+ expected = model (** torch_deepcopy (b1 ))
54+
55+ with torch_export_patches (patch_transformers = True ):
56+ for exporter , filename in filenames .items ():
57+ with self .subTest (exporter = exporter ):
58+ to_onnx (
59+ model ,
60+ kwargs = inputs ,
61+ dynamic_shapes = ds ,
62+ exporter = exporter ,
63+ filename = filename ,
64+ )
65+ for exporter , filename in filenames .items ():
66+ with self .subTest (exporter = f"validate-{ exporter } " ):
67+ sess = onnxruntime .InferenceSession (
68+ filename , providers = ["CPUExecutionProvider" ]
69+ )
70+ feeds = make_feeds (sess , b1 , use_numpy = True )
71+ got = sess .run (None , feeds )
72+ diff = max_diff (expected , got )
73+ assert diff ["abs" ] <= 1e-5 , f"diff={ diff } "
74+
75+ b1 ["attention_mask" ][:, :] = 1
76+ expected = model (** torch_deepcopy (b1 ))
77+ for exporter , filename in filenames .items ():
78+ with self .subTest (exporter = f"full-mask-{ exporter } " ):
79+ sess = onnxruntime .InferenceSession (
80+ filename , providers = ["CPUExecutionProvider" ]
81+ )
82+ feeds = make_feeds (sess , b1 , use_numpy = True )
83+ got = sess .run (None , feeds )
84+ diff = max_diff (expected , got )
85+ assert diff ["abs" ] <= 1e-5 , f"diff={ diff } "
86+
3287
if __name__ == "__main__":
    # Allow running this test file directly; verbosity=2 prints one
    # line per test.
    unittest.main(verbosity=2)