from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.export.api import to_onnx


class TestRtSession(ExtTestCase):
+    def simple_generate_with_cache(
+        self, model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
+    ):
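+        # Reference greedy decoding with a KV cache, used below to validate onnx_generate.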
+        # First call: prefill
+        outputs = model(
+            input_ids,
+            attention_mask=torch.ones(
+                input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+            ),
+            use_cache=True,
+        )
+
+        # Next calls: decode
+        for _ in range(max_new_tokens):
+            next_token_logits = outputs.logits[:, -1, :]
+            past_key_values = outputs.past_key_values
+            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+            if next_token_id.item() == eos_token_id:
+                break
+            input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+            outputs = model(
+                next_token_id,
+                use_cache=True,
+                past_key_values=past_key_values,
+                attention_mask=torch.ones(
+                    input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+                ),
+            )
+        return input_ids
+
    @hide_stdout()
    def test_onnx_generate(self):
-        from experimental_experiment.torch_interpreter import to_onnx
-
        mid = "arnir0/Tiny-LLM"
        print("-- test_onnx_generate: get model")
        data = get_untrained_model_with_inputs(mid)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        del inputs["position_ids"]
        del ds["position_ids"]
        input_ids = inputs["input_ids"]
+        print("----", input_ids.shape)
        folder = self.get_dump_folder("test_onnx_generate")
        model_name = os.path.join(folder, "model.onnx")
        print("-- test_onnx_generate: export model")
@@ -29,13 +59,24 @@ def test_onnx_generate(self):
            kwargs=inputs,
            dynamic_shapes=ds,
            filename=model_name,
+            exporter="custom",
        )

        print("-- test_onnx_generate: generate")
        res = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
+        n_inputs = input_ids.shape[1]
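+        # The generated sequence must start with the original prompt tokens.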
+        self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
        self.assertEqual(res.dtype, torch.int64)
        self.assertEqual(res.shape, (1, 13))
        print("-- test_onnx_generate: done")
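+        # Eager reference: greedy decoding with a KV cache must match the ONNX generation.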
+        # expected = model.generate(input_ids[:1], max_new_tokens=10)
+        expected = self.simple_generate_with_cache(model, input_ids[:1], 2, max_new_tokens=10)
+        self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
+        print("******", res)
+        print("******", expected)
+        self.assertEqual(expected.dtype, torch.int64)
+        self.assertEqual(expected.shape, (1, 13))
+        self.assertEqualArray(expected, res)


if __name__ == "__main__":