:epkg:`microsoft/Phi-1.5` is a small LLM. The example given
"""

import os
import time
import sys
import pandas
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import to_any, get_weight_type
from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config, task_from_id
from onnx_diagnostic.tasks import random_input_kwargs
from onnx_diagnostic.export.api import to_onnx


device = "cuda" if torch.cuda.is_available() else "cpu"
data = []

print("-- load the model...")
if unit_test_going():
    # unit_test_going() returns True if UNITTEST_GOING is 1.
    # The example switches to a faster scenario.
    model_id = "arnir0/Tiny-LLM"
    data_export = get_untrained_model_with_inputs(model_id)
    model = data_export["model"]
    export_inputs = data_export["inputs"]
    export_shapes = data_export["dynamic_shapes"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    model_id = "microsoft/phi-1_5"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = get_pretrained_config(model_id)
    task = task_from_id(model_id)
    kwargs, fct = random_input_kwargs(config, task)
    res = fct(model, config, add_second_input=False, **kwargs)
    export_inputs = res["inputs"]
    export_shapes = res["dynamic_shapes"]
model = model.to(device)
print("-- done.")

print("-- compute the answer...")
begin = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=100)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="generate", duration=duration))
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
def simple_generate_with_cache(
    model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
):
    # First call: prefill
    outputs = model(input_ids, use_cache=True)

    # Next calls: decode
    for _ in tqdm(list(range(max_new_tokens))):
        next_token_logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        # The most probable next token is chosen.
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # But we could select it using a multinomial law
        # (a complete sketch follows this function).
        # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
        # <<< top_probs, top_indices = torch.topk(probs, top_k)
        # <<< next_token_id = top_indices[torch.multinomial(top_probs, 1)]

        if next_token_id.item() == eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

        # Feed only the new token, but with the cache.
        outputs = model(next_token_id, use_cache=True, past_key_values=past_key_values)

    return input_ids
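
# %%
# Instead of the greedy ``argmax`` above, the next token could be drawn with a
# multinomial law as hinted in the comments of ``simple_generate_with_cache``.
# A minimal, self-contained sketch, not used below; the function name and the
# ``temperature`` and ``top_k`` defaults are only illustrative.


def sample_next_token(next_token_logits, temperature=0.8, top_k=50):
    # Scale the logits by the temperature and keep the top_k candidates.
    probs = torch.softmax(next_token_logits / temperature, dim=-1)
    top_probs, top_indices = torch.topk(probs, top_k)
    # Draw one candidate per row according to its probability,
    # then map it back to the original token id.
    choice = torch.multinomial(top_probs, num_samples=1)
    return torch.gather(top_indices, -1, choice)
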
print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = simple_generate_with_cache(
    model, inputs.input_ids, tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="custom", duration=duration))

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)

# %%
# Method generate for onnx models
# ===============================
#
# We first need to export the model into ONNX.
#
# ONNX Conversion
# +++++++++++++++

if "position_ids" in export_inputs:
    del export_inputs["position_ids"]
    del export_shapes["position_ids"]
dtype = get_weight_type(model)
print("-- model dtype:", dtype)
export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
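
# An optional look at what will actually be exported, using ``string_type`` as
# above; this inspection line is only illustrative and can be removed.
print("-- export inputs:", string_type(export_inputs, with_shape=True))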
exporter = "custom" if "custom" in sys.argv else "onnx-dynamo"
model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
if not os.path.exists(model_name):
    # This step is slow so let's skip it if it was already done.
    print("-- conversion to ONNX.")
    begin = time.perf_counter()
    with torch_export_patches(patch_transformers=True):
        to_onnx(
            model,
            (),
            kwargs=to_any(export_inputs, device),
            dynamic_shapes=export_shapes,
            filename=model_name,
            verbose=1,
            exporter=exporter,
        )
    duration = time.perf_counter() - begin
    print(f"-- done in {duration}")

# %%
# onnx_generate
# +++++++++++++
#
# Then we can call the generate method for two tokens.
# This function is part of :epkg:`onnx_diagnostic` but follows the implementation
# seen earlier for a torch model.
# Let's first ask the function to return the session so that it does not have to
# be created again on the second call.

_res, session = onnx_generate(
    model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
)
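
# A quick look at the warm-up output; this inspection line is only illustrative
# and can be removed.
print("-- warm-up output:", string_type(_res, with_shape=True))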

# And now the full answer.
print("-- compute the answer with onnx_generate...")
begin = time.perf_counter()
outputs = onnx_generate(
    session, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="onnx", duration=duration))

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)


# %%
# Plots
# =====
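#
# A minimal sketch of how the collected durations could be compared with pandas,
# assuming only the ``data`` list filled above; the original example may plot
# this differently.

df = pandas.DataFrame(data).set_index("name")
print(df)
df["duration"].plot.barh(title="duration of the generation loops")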