Skip to content

Commit 2d34648

Browse files
authored
update version (#384)
* upgrade version * doc * fix example * bug * fix * doc
1 parent 7a28903 commit 2d34648

File tree

7 files changed

+88
-17
lines changed

7 files changed

+88
-17
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.8.10
5+
++++++
6+
7+
* :pr:`384`: add ``weights_only=False`` when using :func:`torch.load`
8+
49
0.8.9
510
+++++
611

_doc/examples/plot_export_tiny_llm_method_generate.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,9 @@ def generate_text(
4848
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
4949
return generated_text
5050

51-
# Define your prompt
5251

53-
54-
prompt = "Continue: it rains..."
52+
# Define your prompt
53+
prompt = "Continue: it rains, what should I do?"
5554
generated_text = generate_text(prompt, model, tokenizer)
5655
print("-----------------")
5756
print(generated_text)
@@ -69,7 +68,7 @@ def generate_text(
6968
# If the default settings do not work, ``skip_kwargs_names`` and ``dynamic_shapes``
7069
# can be changed to remove some undesired inputs or add more dynamic dimensions.
7170

72-
filename = "plot_export_tiny_llm_method_generate.onnx"
71+
filename = "plot_export_tiny_llm_method_generate.custom.onnx"
7372
forward_replacement = method_to_onnx(
7473
model,
7574
method_name="forward", # default value
@@ -87,8 +86,12 @@ def generate_text(
8786
# The input used in the example has a batch size equal to 1, all
8887
# inputs going through method forward will have the same batch size.
8988
# To force the dynamism of this dimension, we need to indicate
90-
# which inputs has a batch size.
89+
# which inputs have a batch size.
9190
dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
91+
# Earlier versions of pytorch did not accept a dynamic batch size equal to 1,
92+
# this last parameter can be added to expand some inputs if the batch size is 1.
93+
# The exporter should work without.
94+
expand_batch_for={"input_ids", "attention_mask", "past_key_values"},
9295
)
9396

9497
# %%
@@ -139,6 +142,51 @@ def generate_text(
139142
df = pandas.DataFrame(data)
140143
print(df)
141144

145+
# %%
146+
# Minimal script to export a LLM
147+
# ++++++++++++++++++++++++++++++
148+
#
149+
# The following lines are a condensed copy with less comments.
150+
151+
# from HuggingFace
152+
print("----------------")
153+
MODEL_NAME = "arnir0/Tiny-LLM"
154+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
155+
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
156+
157+
# to export into onnx
158+
forward_replacement = method_to_onnx(
159+
model,
160+
method_name="forward",
161+
exporter="onnx-dynamo",
162+
filename="plot_export_tiny_llm_method_generate.dynamo.onnx",
163+
patch_kwargs=dict(patch_transformers=True),
164+
verbose=0,
165+
convert_after_n_calls=3,
166+
dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
167+
)
168+
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
169+
170+
# from HuggingFace again
171+
prompt = "Continue: it rains, what should I do?"
172+
inputs = tokenizer(prompt, return_tensors="pt")
173+
outputs = model.generate(
174+
input_ids=inputs["input_ids"],
175+
attention_mask=inputs["attention_mask"],
176+
max_length=100,
177+
temperature=1,
178+
top_k=50,
179+
top_p=0.95,
180+
do_sample=True,
181+
)
182+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
183+
print("prompt answer:", generated_text)
184+
185+
# to check discrepancies
186+
data = forward_replacement.check_discrepancies()
187+
df = pandas.DataFrame(data)
188+
print(df)
189+
142190

143191
# %%
144192
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)

_doc/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ The function replaces dynamic dimensions defined as strings by
240240
Older versions
241241
==============
242242

243+
* `0.8.10 <../v0.8.10/index.html>`_
243244
* `0.8.9 <../v0.8.9/index.html>`_
244-
* `0.8.8 <../v0.8.8/index.html>`_
245245
* `0.7.16 <../v0.7.16/index.html>`_
246246
* `0.6.3 <../v0.6.3/index.html>`_
247247
* `0.5.0 <../v0.5.0/index.html>`_

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.8.9"
6+
__version__ = "0.8.10"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/export/api.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -445,10 +445,6 @@ def forward(self, *args, **kwargs):
445445
and not isinstance(v, (bool, int, float))
446446
}
447447
)
448-
if self.expand_batch_for:
449-
# extends the inputs to artificially create a batch dimension != 1.
450-
inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
451-
inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
452448
inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
453449
# reorders the parameter following the method signature.
454450
inp_kwargs = self._reorder_kwargs(inp_kwargs)
@@ -557,6 +553,10 @@ def __init__(self, parent):
557553
else:
558554
a, kw = self._inputs[-1]
559555
nds = [self.dynamic_shapes]
556+
if self.expand_batch_for:
557+
# extends the inputs to artificially create a batch dimension != 1.
558+
a = self._expand_batch_dimension(a, self.expand_batch_for)
559+
kw = self._expand_batch_dimension(kw, self.expand_batch_for)
560560
if self.verbose:
561561
print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
562562
print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
@@ -738,7 +738,9 @@ def check_discrepancies(
738738
:param verbose: verbosity
739739
:return: results, a list of dictionaries, ready to be consumed by a dataframe
740740
"""
741-
assert self._export_done, "The onnx export was not done."
741+
assert (
742+
self._export_done
743+
), f"The onnx export was not done, only {len(self._inputs)} were stored."
742744
assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
743745
assert os.path.exists(
744746
self._output_file
@@ -750,17 +752,29 @@ def check_discrepancies(
750752
classes = [
751753
cls
752754
for cls in self._serialization_classes
753-
if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
755+
if cls
756+
not in {
757+
int,
758+
float,
759+
bool,
760+
str,
761+
torch.Tensor,
762+
list,
763+
set,
764+
dict,
765+
torch.device,
766+
torch.dtype,
767+
}
754768
]
755769
if verbose:
756770
print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
757771
print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
758772
with torch.serialization.safe_globals(classes):
759-
inputs = torch.load(self._input_file)
773+
inputs = torch.load(self._input_file, weights_only=False)
760774
if verbose:
761775
print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
762776
with torch.serialization.safe_globals(classes):
763-
outputs = torch.load(self._output_file)
777+
outputs = torch.load(self._output_file, weights_only=False)
764778
assert len(inputs) == len(outputs), (
765779
f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
766780
f"inputs={string_type(inputs, with_shape=True)}, "

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def serialization_functions(
305305

306306

307307
def unregister_class_serialization(cls: type, verbose: int = 0):
308-
"""Undo the registration."""
308+
"""Undo the registration for a class."""
309309
# torch.utils._pytree._deregister_pytree_flatten_spec(cls)
310310
if cls in torch.fx._pytree.SUPPORTED_NODES:
311311
del torch.fx._pytree.SUPPORTED_NODES[cls]
@@ -333,6 +333,10 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
333333

334334

335335
def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
336+
"""
337+
Undo the registration made by
338+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`.
339+
"""
336340
cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
337341
for cls in cls_ensemble:
338342
if undo.get(cls.__name__, False):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "onnx-diagnostic"
3-
version = "0.8.9"
3+
version = "0.8.10"
44
description = "Tools to help converting pytorch models into ONNX."
55
readme = "README.rst"
66
authors = [

0 commit comments

Comments
 (0)