@@ -138,6 +138,49 @@ def generate_text(
 df = pandas.DataFrame(data)
 print(df)
 
+# %%
+# Minimal script to export an LLM
+# +++++++++++++++++++++++++++++++
+#
+# The following lines are a condensed copy of the steps above, with fewer comments.
+
+# download the tokenizer and the model from HuggingFace
+MODEL_NAME = "arnir0/Tiny-LLM"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+# to export into ONNX
+forward_replacement = method_to_onnx(
+    model,
+    method_name="forward",
+    exporter="custom",
+    filename="plot_export_tiny_llm_method_generate.onnx",
+    patch_kwargs=dict(patch_transformers=True),
+    verbose=0,
+    convert_after_n_calls=3,
+    dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
+)
+
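+# Judging from the arguments above, method_to_onnx replaces model.forward
+# with a wrapper: the first forward calls run eagerly, the third one
+# (convert_after_n_calls=3) triggers the actual ONNX export, and the batch
+# dimension of the inputs listed in dynamic_batch_for is kept dynamic.
+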
+# from HuggingFace again
+prompt = "Continue: it rains..."
+inputs = tokenizer(prompt, return_tensors="pt")
+outputs = model.generate(
+    input_ids=inputs["input_ids"],
+    attention_mask=inputs["attention_mask"],
+    max_length=50,
+    temperature=1,
+    top_k=50,
+    top_p=0.95,
+    do_sample=True,
+)
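+# generate calls forward once per decoding step, so the deferred export
+# should have been triggered during the call above.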
+generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+print("prompt answer:", generated_text)
+
+# to check discrepancies between the original model and its ONNX export
+data = forward_replacement.check_discrepancies()
+df = pandas.DataFrame(data)
+print(df)
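+# The resulting DataFrame presumably reports, for each recorded call, how far
+# the ONNX outputs are from the eager ones, which gives a quick sanity check
+# that the export is numerically faithful.
+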
 
 # %%
 doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)