 import requests
 import sys
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
 from urllib.parse import urlparse
-from onnx import helper, save_model, external_data_helper, ModelProto
+from onnx import ModelProto, TensorProto

 CACHE_SUBDIR = "onnx-diagnostic"

@@ -114,87 +114,56 @@ def _make_model(self, model, verbose: int = 0):
             self.make_lm_head(module)


-def save_model_builder(self, out_dir: Optional[str] = "", verbose: int = 0) -> ModelProto:
+def save_model_builder(
+    self, out_dir: Optional[str] = "", verbose: int = 0
+) -> Union[str, ModelProto]:
     """
     Saves a model created by function :func:`create_model_builder`.
     If out_dir is empty or not specified, the function still returns the
     generated model.
     """
-    if verbose:
-        print(f"[save_model_builder] Saving ONNX model in {out_dir}")
-
-    # Create ONNX model
-    model = helper.make_model(
-        opset_imports=[
-            self.clear_field(
-                helper.make_operatorsetid("", 21 if self.quant_attrs["use_qdq"] else 14),
-                "domain",
-            ),
-            helper.make_operatorsetid("com.microsoft", 1),
-        ],
-        ir_version=7,
-        producer_name="onnxruntime-genai",
-        producer_version="0.0.0",
-        graph=self.make_graph(
-            name="main_graph",
-            inputs=self.inputs,
-            outputs=self.outputs,
-            initializer=self.initializers,
-            value_info=self.value_infos,
-            nodes=self.nodes,
-        ),
-    )
-
-    # Load external data into ONNX model
-    external_data_helper.load_external_data_for_model(model, self.cache_dir)
-
-    # Delete external data files on disk before re-saving
-    for path in os.listdir(self.cache_dir):
-        if path.endswith(".bin"):
-            os.remove(os.path.join(self.cache_dir, path))
+    import onnx_ir

-    # Delete temporary cache dir if empty
-    # if len(os.listdir(self.cache_dir)) == 0:
-    # os.rmdir(self.cache_dir)
+    if verbose:
+        print(f"[save_model_builder] Saving ONNX model in {out_dir!r}")

-    # Quantize ONNX model to desired precision
+    # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
     already_quantized_in_qdq_format = (
         self.quant_type is not None and self.quant_attrs["use_qdq"]
-    )  # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
-    if self.onnx_dtype == "int4" and not already_quantized_in_qdq_format:
-        model = self.to_int4(model)
+    )
+    model = (
+        self.to_int4()
+        if self.onnx_dtype in {onnx_ir.DataType.INT4, onnx_ir.DataType.UINT4}
+        and not already_quantized_in_qdq_format
+        else self.model
+    )
+    model.graph.sort()
+    if not out_dir:
+        return onnx_ir.to_proto(model)

-    # Save ONNX model with only one external data file and delete any existing duplicate copies
-    if out_dir:
-        out_path = os.path.join(out_dir, self.filename)
-        data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
-        if os.path.exists(out_path):
-            if verbose:
-                print(f"[save_model_builder] Overwriting {out_path!r}")
-            os.remove(out_path)
-        if os.path.exists(data_path):
-            if verbose:
-                print(f"[save_model_builder] Overwriting {data_path!r}")
-            os.remove(data_path)

-    if out_dir:
-        location = os.path.basename(data_path)
-        if os.path.exists(location):
-            os.remove(location)
+    # Save ONNX model with only one external data file and delete any existing duplicate copies
+    out_path = os.path.join(out_dir, self.filename)
+    data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
+    if os.path.exists(out_path):
         if verbose:
-            print(f"[save_model_builder] out_path={out_path!r}")
-            print(f"[save_model_builder] location={location!r}")
-        save_model(
-            model,
-            out_path,
-            save_as_external_data=True,
-            all_tensors_to_one_file=True,
-            location=location,
-            size_threshold=1024,
-            convert_attribute=False,
-        )
-        return None
-    return model
+            print(f"[save_model_builder] Overwriting {out_path!r}")
+        os.remove(out_path)
+    if os.path.exists(data_path):
+        if verbose:
+            print(f"[save_model_builder] Overwriting {data_path!r}")
+        os.remove(data_path)
+
+    onnx_ir.save(
+        model,
+        out_path,
+        external_data=os.path.basename(data_path),
+        size_threshold_bytes=2**10,
+    )
+    if verbose:
+        print(f"[save_model_builder] saved in {out_dir!r}")
+
+    return out_path


 def create_model_builder(
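
A minimal usage sketch of the new `Union[str, ModelProto]` contract (hedged: `builder` stands for whatever object `create_model_builder` returns, and "dump" is a hypothetical output directory; the real file name comes from `self.filename`):

proto = save_model_builder(builder, out_dir="", verbose=1)
assert isinstance(proto, ModelProto)  # empty out_dir -> in-memory proto
path = save_model_builder(builder, out_dir="dump", verbose=1)
assert path == os.path.join("dump", builder.filename)  # written next to a single "<filename>.data" file
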
@@ -335,13 +306,23 @@ def _post(onnx_model):
     for c in remove:
         delattr(config, c)

-    onnx_model = cls(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+    convert = {
+        "fp32": TensorProto.FLOAT,
+        "fp16": TensorProto.FLOAT16,
+        "bfp16": TensorProto.BFLOAT16,
+    }
+    assert (
+        precision in convert
+    ), f"Unexpected value for precision={precision!r}, should be in {convert}"
+    onnx_model = cls(
+        config, io_dtype, convert[precision], execution_provider, cache_dir, extra_options
+    )

     if post:
         post(onnx_model)
     _make_model(onnx_model, model, verbose=verbose)

-    assert onnx_model.nodes, (
+    assert onnx_model.model, (
         f"No node in the model, io_dtype={io_dtype!r}, "
         f"precision={precision!r}, execution_provider={execution_provider!r}, "
         f"extra_options={extra_options!r}, cache_dir={cache_dir!r}, "
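
The precision strings accepted above boil down to a three-entry lookup; a standalone sketch restating it (the "bfp16" key is copied verbatim from the diff, and any other string trips the assertion in `create_model_builder`):

from onnx import TensorProto

convert = {"fp32": TensorProto.FLOAT, "fp16": TensorProto.FLOAT16, "bfp16": TensorProto.BFLOAT16}
assert convert["fp16"] == TensorProto.FLOAT16 == 10  # TensorProto enum values are plain ints
assert "int8" not in convert  # would fail the precision check in create_model_builder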