Skip to content

Commit ae1e52e

Browse files
committed
fix model builder
1 parent 5a7b27c commit ae1e52e

File tree

4 files changed

+62
-97
lines changed

4 files changed

+62
-97
lines changed

_unittests/ut_helpers/test_doc_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_custom_doc_kernels_layer_normalization(self):
5656
)
5757
expected = torch_sess.run(None, feeds)
5858
got = torch_sess_custom.run(None, feeds)
59-
self.assertEqualAny(expected, got)
59+
self.assertEqualAny(expected, got, atol=1e-3)
6060

6161
def test_custom_doc_kernels_matmul(self):
6262
model = oh.make_model(

_unittests/ut_helpers/test_model_builder_helper.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import unittest
32
from onnx_diagnostic.ext_test_case import (
43
ExtTestCase,
@@ -48,32 +47,17 @@ def test_model_builder_id(self):
4847
cache_dir=folder,
4948
verbose=1,
5049
)
51-
self.assertGreater(len(onnx_model.nodes), 5)
50+
self.assertGreater(onnx_model.model.graph.num_nodes(), 5)
51+
model_name = save_model_builder(onnx_model, folder, verbose=1)
52+
self.assertExists(model_name)
5253

53-
proto = save_model_builder(onnx_model, verbose=1)
5454
import onnxruntime
5555

56-
onnxruntime.InferenceSession(
57-
proto.SerializeToString(), providers=["CPUExecutionProvider"]
58-
)
59-
60-
# We need to start again.
61-
onnx_model = create_model_builder(
62-
data["configuration"],
63-
data["model"],
64-
precision="fp32",
65-
execution_provider="cpu",
66-
cache_dir=folder,
67-
verbose=1,
68-
)
69-
save_model_builder(onnx_model, folder, verbose=1)
70-
model_name = os.path.join(folder, "model.onnx")
71-
self.assertExists(model_name)
72-
73-
feeds = make_feeds(proto, data["inputs"], use_numpy=True)
56+
sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
57+
del data["inputs"]["position_ids"]
58+
feeds = make_feeds([i.name for i in sess.get_inputs()], data["inputs"], use_numpy=True)
7459
expected = data["model"](**data["inputs"])
7560

76-
sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
7761
try:
7862
got = sess.run(None, feeds)
7963
except onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument as e:

onnx_diagnostic/helpers/model_builder_helper.py

Lines changed: 54 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import requests
44
import sys
55
from pathlib import Path
6-
from typing import Any, Optional
6+
from typing import Any, Optional, Union
77
from urllib.parse import urlparse
8-
from onnx import helper, save_model, external_data_helper, ModelProto
8+
from onnx import ModelProto, TensorProto
99

1010
CACHE_SUBDIR = "onnx-diagnostic"
1111

@@ -114,87 +114,58 @@ def _make_model(self, model, verbose: int = 0):
114114
self.make_lm_head(module)
115115

116116

117-
def save_model_builder(self, out_dir: Optional[str] = "", verbose: int = 0) -> ModelProto:
117+
def save_model_builder(
118+
self, out_dir: Optional[str] = "", verbose: int = 0
119+
) -> Union[str, ModelProto]:
118120
"""
119121
Saves a model created by function :func:`create_model_builder`.
120122
If out_dir is empty or not specified, the function still returns the
121123
generated model.
122124
"""
123-
if verbose:
124-
print(f"[save_model_builder] Saving ONNX model in {out_dir}")
125-
126-
# Create ONNX model
127-
model = helper.make_model(
128-
opset_imports=[
129-
self.clear_field(
130-
helper.make_operatorsetid("", 21 if self.quant_attrs["use_qdq"] else 14),
131-
"domain",
132-
),
133-
helper.make_operatorsetid("com.microsoft", 1),
134-
],
135-
ir_version=7,
136-
producer_name="onnxruntime-genai",
137-
producer_version="0.0.0",
138-
graph=self.make_graph(
139-
name="main_graph",
140-
inputs=self.inputs,
141-
outputs=self.outputs,
142-
initializer=self.initializers,
143-
value_info=self.value_infos,
144-
nodes=self.nodes,
145-
),
146-
)
147-
148-
# Load external data into ONNX model
149-
external_data_helper.load_external_data_for_model(model, self.cache_dir)
150-
151-
# Delete external data files on disk before re-saving
152-
for path in os.listdir(self.cache_dir):
153-
if path.endswith(".bin"):
154-
os.remove(os.path.join(self.cache_dir, path))
125+
import onnx_ir
155126

156-
# Delete temporary cache dir if empty
157-
# if len(os.listdir(self.cache_dir)) == 0:
158-
# os.rmdir(self.cache_dir)
127+
if verbose:
128+
print(f"[save_model_builder] Saving ONNX model in {out_dir!r}")
159129

160-
# Quantize ONNX model to desired precision
130+
# Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
161131
already_quantized_in_qdq_format = (
162132
self.quant_type is not None and self.quant_attrs["use_qdq"]
163-
) # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
164-
if self.onnx_dtype == "int4" and not already_quantized_in_qdq_format:
165-
model = self.to_int4(model)
133+
)
134+
model = (
135+
self.to_int4()
136+
if self.onnx_dtype in {onnx_ir.DataType.INT4, onnx_ir.DataType.UINT4}
137+
and not already_quantized_in_qdq_format
138+
else self.model
139+
)
140+
model.graph.sort()
141+
if not out_dir:
142+
return onnx_ir.to_proto(model)
166143

167-
# Save ONNX model with only one external data file and delete any existing duplicate copies
168-
if out_dir:
169-
out_path = os.path.join(out_dir, self.filename)
170-
data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
171-
if os.path.exists(out_path):
172-
if verbose:
173-
print(f"[save_model_builder] Overwriting {out_path!r}")
174-
os.remove(out_path)
175-
if os.path.exists(data_path):
176-
if verbose:
177-
print(f"[save_model_builder] Overwriting {data_path!r}")
178-
os.remove(data_path)
144+
out_path = os.path.join(out_dir, self.filename)
145+
data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
179146

180-
if out_dir:
181-
location = os.path.basename(data_path)
182-
if os.path.exists(location):
183-
os.remove(location)
147+
# Save ONNX model with only one external data file and delete any existing duplicate copies
148+
out_path = os.path.join(out_dir, self.filename)
149+
data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
150+
if os.path.exists(out_path):
184151
if verbose:
185-
print(f"[save_model_builder] out_path={out_path!r}")
186-
print(f"[save_model_builder] location={location!r}")
187-
save_model(
188-
model,
189-
out_path,
190-
save_as_external_data=True,
191-
all_tensors_to_one_file=True,
192-
location=location,
193-
size_threshold=1024,
194-
convert_attribute=False,
195-
)
196-
return None
197-
return model
152+
print(f"[save_model_builder] Overwriting {out_path!r}")
153+
os.remove(out_path)
154+
if os.path.exists(data_path):
155+
if verbose:
156+
print(f"[save_model_builder] Overwriting {data_path!r}")
157+
os.remove(data_path)
158+
159+
onnx_ir.save(
160+
model,
161+
out_path,
162+
external_data=os.path.basename(data_path),
163+
size_threshold_bytes=2**10,
164+
)
165+
if verbose:
166+
print(f"[save_model_builder] saved in {out_dir!r}")
167+
168+
return out_path
198169

199170

200171
def create_model_builder(
@@ -335,13 +306,23 @@ def _post(onnx_model):
335306
for c in remove:
336307
delattr(config, c)
337308

338-
onnx_model = cls(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
309+
convert = {
310+
"fp32": TensorProto.FLOAT,
311+
"fp16": TensorProto.FLOAT16,
312+
"bfp16": TensorProto.BFLOAT16,
313+
}
314+
assert (
315+
precision in convert
316+
), f"Unexpected value for precision={precision!r}, should be in {convert}"
317+
onnx_model = cls(
318+
config, io_dtype, convert[precision], execution_provider, cache_dir, extra_options
319+
)
339320

340321
if post:
341322
post(onnx_model)
342323
_make_model(onnx_model, model, verbose=verbose)
343324

344-
assert onnx_model.nodes, (
325+
assert onnx_model.model, (
345326
f"No node in the model, io_dtype={io_dtype!r}, "
346327
f"precision={precision!r}, execution_provider={execution_provider!r}, "
347328
f"extra_options={extra_options!r}, cache_dir={cache_dir!r}, "

onnx_diagnostic/torch_models/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def call_torch_export_export(
10031003
if "ERR_export_export" in summary:
10041004
return summary, data
10051005

1006-
disc = max_diff(data["expected"], expected)
1006+
disc = max_diff(data["run_expected"], expected)
10071007
for k, v in disc.items():
10081008
summary[f"disc_exported_{k}"] = str(v)
10091009
if verbose:

0 commit comments

Comments
 (0)