Skip to content

Commit 58cc76c

Browse files
committed
bug
1 parent 64de535 commit 58cc76c

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.8.10
55
++++++
66

7+
* :pr:`384`: add ``weights_only=False`` when using :func:`torch.load`
8+
79
0.8.9
810
+++++
911

_doc/examples/plot_export_tiny_llm_method_generate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def generate_text(
145145
# The following lines are a condensed copy with less comments.
146146

147147
# from HuggingFace
148+
print("----------------")
148149
MODEL_NAME = "arnir0/Tiny-LLM"
149150
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
150151
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
@@ -160,14 +161,15 @@ def generate_text(
160161
convert_after_n_calls=3,
161162
dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
162163
)
164+
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
163165

164166
# from HuggingFace again
165-
prompt = "Continue: it rains..."
167+
prompt = "Continue: it rains, what should I do?"
166168
inputs = tokenizer(prompt, return_tensors="pt")
167169
outputs = model.generate(
168170
input_ids=inputs["input_ids"],
169171
attention_mask=inputs["attention_mask"],
170-
max_length=50,
172+
max_length=100,
171173
temperature=1,
172174
top_k=50,
173175
top_p=0.95,

onnx_diagnostic/export/api.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,9 @@ def check_discrepancies(
738738
:param verbose: verbosity
739739
:return: results, a list of dictionaries, ready to be consumed by a dataframe
740740
"""
741-
assert self._export_done, "The onnx export was not done."
741+
assert (
742+
self._export_done
743+
), f"The onnx export was not done, only {len(self._inputs)} were stored."
742744
assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
743745
assert os.path.exists(
744746
self._output_file
@@ -768,11 +770,11 @@ def check_discrepancies(
768770
print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
769771
print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
770772
with torch.serialization.safe_globals(classes):
771-
inputs = torch.load(self._input_file)
773+
inputs = torch.load(self._input_file, weights_only=False)
772774
if verbose:
773775
print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
774776
with torch.serialization.safe_globals(classes):
775-
outputs = torch.load(self._output_file)
777+
outputs = torch.load(self._output_file, weights_only=False)
776778
assert len(inputs) == len(outputs), (
777779
f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
778780
f"inputs={string_type(inputs, with_shape=True)}, "

0 commit comments

Comments
 (0)