22import unittest
33import torch
44from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout
5+ from onnx_diagnostic .helpers import max_diff , flatten_object
56from onnx_diagnostic .helpers .rt_helper import onnx_generate
7+ from onnx_diagnostic .helpers .torch_helper import torch_deepcopy
8+ from onnx_diagnostic .helpers .ort_session import InferenceSessionForTorch
69from onnx_diagnostic .torch_models .hghub import get_untrained_model_with_inputs
710from onnx_diagnostic .torch_export_patches import torch_export_patches
811from onnx_diagnostic .export .api import to_onnx
912
1013
1114class TestRtSession (ExtTestCase ):
1215 def simple_generate_with_cache (
13- self , model , input_ids : torch .Tensor , eos_token_id : int , max_new_tokens : int = 100
16+ self ,
17+ model ,
18+ input_ids : torch .Tensor ,
19+ eos_token_id : int ,
20+ session : InferenceSessionForTorch ,
21+ max_new_tokens : int = 100 ,
1422 ):
1523 # First call: prefill
1624 outputs = model (
1725 input_ids ,
26+ use_cache = True ,
1827 attention_mask = torch .ones (
1928 input_ids .shape , dtype = input_ids .dtype , device = input_ids .device
2029 ),
21- use_cache = True ,
2230 )
2331
2432 # Next calls: decode
2533 for _ in range (max_new_tokens ):
2634 next_token_logits = outputs .logits [:, - 1 , :]
27- past_key_values = outputs .past_key_values
2835 next_token_id = torch .argmax (next_token_logits , dim = - 1 , keepdim = True )
2936 if next_token_id .item () == eos_token_id :
3037 break
3138 input_ids = torch .cat ([input_ids , next_token_id ], dim = - 1 )
39+ attention_mask = torch .ones (
40+ input_ids .shape , dtype = input_ids .dtype , device = input_ids .device
41+ )
42+ feeds = dict (
43+ zip (
44+ session .input_names ,
45+ torch_deepcopy (
46+ flatten_object (
47+ [next_token_id , attention_mask , outputs .past_key_values ]
48+ )
49+ ),
50+ )
51+ )
52+ onnx_results = session .run (None , feeds )
3253 outputs = model (
3354 next_token_id ,
3455 use_cache = True ,
35- past_key_values = past_key_values ,
36- attention_mask = torch .ones (
37- input_ids .shape , dtype = input_ids .dtype , device = input_ids .device
38- ),
56+ past_key_values = outputs .past_key_values ,
57+ attention_mask = attention_mask ,
3958 )
59+ diff = max_diff (outputs , onnx_results )
60+ print ("****" , diff )
4061 return input_ids
4162
4263 @hide_stdout ()
@@ -63,14 +84,18 @@ def test_onnx_generate(self):
6384 )
6485
6586 print ("-- test_onnx_generate: generate" )
66- res = onnx_generate (model_name , input_ids [:1 ], 2 , max_new_tokens = 10 )
87+ res , session = onnx_generate (
88+ model_name , input_ids [:1 ], 2 , max_new_tokens = 10 , return_session = True
89+ )
6790 n_inputs = input_ids .shape [1 ]
6891 self .assertEqualArray (input_ids [:1 ], res [:, :n_inputs ])
6992 self .assertEqual (res .dtype , torch .int64 )
7093 self .assertEqual (res .shape , (1 , 13 ))
7194 print ("-- test_onnx_generate: done" )
7295 # expected = model.generate(input_ids[:1], max_new_tokens=10)
73- expected = self .simple_generate_with_cache (model , input_ids [:1 ], 2 , max_new_tokens = 10 )
96+ expected = self .simple_generate_with_cache (
97+ model , input_ids [:1 ], 2 , max_new_tokens = 10 , session = session
98+ )
7499 self .assertEqualArray (input_ids [:1 ], expected [:, :n_inputs ])
75100 print ("******" , res )
76101 print ("******" , expected )
0 commit comments