Skip to content

Commit 64de535

Browse files
committed
fix example
1 parent ddda9d9 commit 64de535

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

_doc/examples/plot_export_tiny_llm_method_generate.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,49 @@ def generate_text(
138138
df = pandas.DataFrame(data)
139139
print(df)
140140

141+
# %%
142+
# Minimal script to export an LLM
143+
# +++++++++++++++++++++++++++++++
144+
#
145+
# The following lines are a condensed copy with fewer comments.
146+
147+
# from HuggingFace
148+
MODEL_NAME = "arnir0/Tiny-LLM"
149+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
150+
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
151+
152+
# to export to ONNX
153+
forward_replacement = method_to_onnx(
154+
model,
155+
method_name="forward",
156+
exporter="custom",
157+
filename="plot_export_tiny_llm_method_generate.onnx",
158+
patch_kwargs=dict(patch_transformers=True),
159+
verbose=0,
160+
convert_after_n_calls=3,
161+
dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
162+
)
163+
164+
# from HuggingFace again
165+
prompt = "Continue: it rains..."
166+
inputs = tokenizer(prompt, return_tensors="pt")
167+
outputs = model.generate(
168+
input_ids=inputs["input_ids"],
169+
attention_mask=inputs["attention_mask"],
170+
max_length=50,
171+
temperature=1,
172+
top_k=50,
173+
top_p=0.95,
174+
do_sample=True,
175+
)
176+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
177+
print("prompt answer:", generated_text)
178+
179+
# to check discrepancies
180+
data = forward_replacement.check_discrepancies()
181+
df = pandas.DataFrame(data)
182+
print(df)
183+
141184

142185
# %%
143186
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)

onnx_diagnostic/export/api.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,19 @@ def check_discrepancies(
750750
classes = [
751751
cls
752752
for cls in self._serialization_classes
753-
if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
753+
if cls
754+
not in {
755+
int,
756+
float,
757+
bool,
758+
str,
759+
torch.Tensor,
760+
list,
761+
set,
762+
dict,
763+
torch.device,
764+
torch.dtype,
765+
}
754766
]
755767
if verbose:
756768
print(f"[method_to_onnx.check_discrepancies] register classes {classes}")

0 commit comments

Comments
 (0)