11import os
22import unittest
33import torch
4- from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout
4+ from onnx_diagnostic .ext_test_case import (
5+ ExtTestCase ,
6+ hide_stdout ,
7+ requires_transformers ,
8+ requires_torch ,
9+ )
510from onnx_diagnostic .helpers import max_diff , flatten_object
6- from onnx_diagnostic .helpers .rt_helper import onnx_generate
11+ from onnx_diagnostic .helpers .rt_helper import onnx_generate , make_empty_cache
712from onnx_diagnostic .helpers .torch_helper import torch_deepcopy
813from onnx_diagnostic .helpers .ort_session import InferenceSessionForTorch
914from onnx_diagnostic .torch_models .hghub import get_untrained_model_with_inputs
@@ -21,16 +26,33 @@ def simple_generate_with_cache(
2126 max_new_tokens : int = 100 ,
2227 ):
2328 # First call: prefill
24- outputs = model (
25- input_ids ,
26- use_cache = True ,
27- attention_mask = torch .ones (
28- input_ids .shape , dtype = input_ids .dtype , device = input_ids .device
29+ attention_mask = torch .ones (
30+ input_ids .shape , dtype = input_ids .dtype , device = input_ids .device
31+ )
32+ feeds = {
33+ ** dict (zip (session .input_names [:2 ], [input_ids , attention_mask ])),
34+ ** make_empty_cache (
35+ input_ids .shape [0 ],
36+ session .input_names [2 :],
37+ session .input_shapes [2 :],
38+ session .input_types [2 :],
2939 ),
40+ }
41+ onnx_results = session .run (None , feeds )
42+
43+ outputs = model (input_ids , use_cache = True , attention_mask = attention_mask )
44+
45+ diff = max_diff (outputs , onnx_results )
46+ assert diff ["abs" ] <= 0.1 , (
47+ f"Unexpected issue with { type (model )} \n diff={ diff } "
48+ f"\n input_ids.shape={ input_ids .shape } "
49+ f"\n expected={ self .string_type (outputs , with_shape = True , with_min_max = True )} "
50+ f"\n got=\n "
51+ f"{ self .string_type (onnx_results , with_shape = True , with_min_max = True )} "
3052 )
3153
3254 # Next calls: decode
33- for _ in range (max_new_tokens ):
55+ for iteration in range (max_new_tokens ):
3456 next_token_logits = outputs .logits [:, - 1 , :]
3557 next_token_id = torch .argmax (next_token_logits , dim = - 1 , keepdim = True )
3658 if next_token_id .item () == eos_token_id :
@@ -42,11 +64,14 @@ def simple_generate_with_cache(
4264 feeds = dict (
4365 zip (
4466 session .input_names ,
45- torch_deepcopy (
46- flatten_object (
47- [next_token_id , attention_mask , outputs .past_key_values ]
67+ [
68+ t .detach ()
69+ for t in torch_deepcopy (
70+ flatten_object (
71+ [next_token_id , attention_mask , outputs .past_key_values ]
72+ )
4873 )
49- ) ,
74+ ] ,
5075 )
5176 )
5277 onnx_results = session .run (None , feeds )
@@ -57,9 +82,17 @@ def simple_generate_with_cache(
5782 attention_mask = attention_mask ,
5883 )
5984 diff = max_diff (outputs , onnx_results )
60- print ("****" , diff )
85+ assert diff ["abs" ] <= 0.1 , (
86+ f"Unexpected issue with { type (model )} , iteration={ iteration } "
87+ f"\n diff={ diff } \n input_ids.shape={ input_ids .shape } "
88+ f"\n expected={ self .string_type (outputs , with_shape = True , with_min_max = True )} "
89+ f"\n got=\n "
90+ f"{ self .string_type (onnx_results , with_shape = True , with_min_max = True )} "
91+ )
6192 return input_ids
6293
94+ @requires_transformers ("4.55" )
95+ @requires_torch ("2.9" )
6396 @hide_stdout ()
6497 def test_onnx_generate (self ):
6598 mid = "arnir0/Tiny-LLM"
@@ -83,25 +116,25 @@ def test_onnx_generate(self):
83116 exporter = "custom" ,
84117 )
85118
86- print ("-- test_onnx_generate: generate" )
87- res , session = onnx_generate (
88- model_name , input_ids [:1 ], 2 , max_new_tokens = 10 , return_session = True
89- )
90- n_inputs = input_ids .shape [1 ]
91- self .assertEqualArray (input_ids [:1 ], res [:, :n_inputs ])
92- self .assertEqual (res .dtype , torch .int64 )
93- self .assertEqual (res .shape , (1 , 13 ))
94- print ("-- test_onnx_generate: done" )
95- # expected = model.generate(input_ids[:1], max_new_tokens=10)
96- expected = self .simple_generate_with_cache (
97- model , input_ids [:1 ], 2 , max_new_tokens = 10 , session = session
98- )
99- self .assertEqualArray (input_ids [:1 ], expected [:, :n_inputs ])
100- print ("******" , res )
101- print ("******" , expected )
102- self .assertEqual (expected .dtype , torch .int64 )
103- self .assertEqual (expected .shape , (1 , 13 ))
104- self .assertEqualArray (expected , res )
119+ print ("-- test_onnx_generate: generate" )
120+ res , session = onnx_generate (
121+ model_name , input_ids [:1 ], 2 , max_new_tokens = 10 , return_session = True
122+ )
123+ n_inputs = input_ids .shape [1 ]
124+ self .assertEqualArray (input_ids [:1 ], res [:, :n_inputs ])
125+ self .assertEqual (res .dtype , torch .int64 )
126+ self .assertEqual (res .shape , (1 , 13 ))
127+ print ("-- test_onnx_generate: done" )
128+ # expected = model.generate(input_ids[:1], max_new_tokens=10)
129+ expected = self .simple_generate_with_cache (
130+ model , input_ids [:1 ], 2 , max_new_tokens = 10 , session = session
131+ )
132+ self .assertEqualArray (input_ids [:1 ], expected [:, :n_inputs ])
133+ print ("******" , res )
134+ print ("******" , expected )
135+ self .assertEqual (expected .dtype , torch .int64 )
136+ self .assertEqual (expected .shape , (1 , 13 ))
137+ self .assertEqualArray (expected , res )
105138
106139
107140if __name__ == "__main__" :