Skip to content

Commit ff6065d

Browse files
committed
Reuse existing quantization file
1 parent 0360e75 commit ff6065d

File tree

1 file changed

+34
-23
lines changed

1 file changed

+34
-23
lines changed

model_navigator/commands/convert/converters/onnx2trt.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,30 @@ def _get_precisions(precision, precision_mode):
6262
return tf32, fp16, bf16, fp8, int8, nvfp4
6363

6464

65+
def _quantize_model(
    navigator_workspace: pathlib.Path,
    batch_dim: int,
    quantized_onnx_path: pathlib.Path,
    onnx_path: pathlib.Path,
    precision: str,
):
    """Quantize an ONNX model with TensorRT ModelOpt and write the result to disk.

    Args:
        navigator_workspace: Workspace directory from which correctness samples are loaded.
        batch_dim: Batch dimension index forwarded to ``load_samples``.
        quantized_onnx_path: Destination path for the quantized ONNX model.
        onnx_path: Path of the source (non-quantized) ONNX model.
        precision: Quantization mode string forwarded to ModelOpt's ``quantize_mode``
            argument (presumably a TensorRT precision name such as "int8" — confirm
            against the caller in ``convert``).
    """
    # Imported lazily inside the function so ModelOpt is only required
    # when quantization is actually performed.
    import modelopt.onnx.quantization as moq  # pytype: disable=import-error # noqa: F401

    correctness_samples = load_samples("correctness_samples", navigator_workspace, batch_dim)
    # Flatten all samples into one {tensor_name: tensor} calibration dict;
    # on duplicate names, tensors from later samples overwrite earlier ones.
    calibration_data = {name: tensor for sample in correctness_samples for name, tensor in sample.items()}

    # Prepare quantization parameters
    quantize_kwargs = {
        "onnx_path": onnx_path.as_posix(),
        "calibration_data": calibration_data,
        "output_path": quantized_onnx_path.as_posix(),
        "quantize_mode": precision,
    }

    moq.quantize(**quantize_kwargs)
    LOGGER.info("Quantized ONNX model saved in {}", quantized_onnx_path)
87+
88+
6589
def convert(
6690
exported_model_path: str,
6791
converted_model_path: str,
@@ -111,7 +135,6 @@ def convert(
111135
exported_model_path = pathlib.Path(exported_model_path)
112136
if not exported_model_path.is_absolute():
113137
exported_model_path = navigator_workspace / exported_model_path
114-
exported_model_path = exported_model_path.as_posix()
115138

116139
if model_name is None:
117140
model_name = navigator_workspace.stem
@@ -125,7 +148,6 @@ def convert(
125148
quantized_onnx_path = pathlib.Path(quantized_onnx_path)
126149
if not quantized_onnx_path.is_absolute():
127150
quantized_onnx_path = navigator_workspace / quantized_onnx_path
128-
quantized_onnx_path = quantized_onnx_path.as_posix()
129151

130152
custom_args = custom_args or {}
131153

@@ -140,9 +162,10 @@ def convert(
140162

141163
# nvfp4 is currently not used as flag for converter, skip it
142164
tf32, fp16, bf16, fp8, int8, _ = _get_precisions(precision, precision_mode)
165+
strongly_typed = False
143166

144167
# Determine the path to use for ONNX model
145-
onnx_path = exported_model_path
168+
onnx_path = pathlib.Path(exported_model_path)
146169

147170
# Check if we need to perform quantization
148171
should_quantize = (
@@ -154,31 +177,19 @@ def convert(
154177
# Use ModelOpt for quantization if needed
155178
if quantized_onnx_path and should_quantize:
156179
LOGGER.info("Quantize model through TensorRT ModelOpt with {} precision", precision)
157-
import modelopt.onnx.quantization as moq # pytype: disable=import-error # noqa: F401
158180

159-
correctness_samples = load_samples("correctness_samples", navigator_workspace, batch_dim)
160-
calibration_data = {name: tensor for sample in correctness_samples for name, tensor in sample.items()}
181+
if not pathlib.Path(quantized_onnx_path).exists():
182+
_quantize_model(navigator_workspace, batch_dim, quantized_onnx_path, onnx_path, precision)
183+
else:
184+
LOGGER.info("Quantized ONNX model already exists in {}", quantized_onnx_path)
161185

162-
# Prepare quantization parameters
163-
quantize_kwargs = {
164-
"onnx_path": onnx_path,
165-
"calibration_data": calibration_data,
166-
"output_path": quantized_onnx_path,
167-
"quantize_mode": precision,
168-
}
169-
170-
moq.quantize(**quantize_kwargs)
171-
LOGGER.info("Quantized ONNX model saved in {}", quantized_onnx_path)
172-
onnx_path = quantized_onnx_path
186+
onnx_path = pathlib.Path(quantized_onnx_path)
173187
# For NVFP4, always use the quantized path (even if not quantized yet)
174188
elif quantized_onnx_path and TensorRTPrecision(precision) == TensorRTPrecision.NVFP4:
175-
onnx_path = quantized_onnx_path
176-
177-
if TensorRTPrecision(precision) == TensorRTPrecision.NVFP4:
178189
strongly_typed = True
179-
else:
180-
strongly_typed = False
181-
network = network_from_onnx_path(onnx_path, flags=onnx_parser_flags, strongly_typed=strongly_typed)
190+
onnx_path = pathlib.Path(quantized_onnx_path)
191+
192+
network = network_from_onnx_path(onnx_path.as_posix(), flags=onnx_parser_flags, strongly_typed=strongly_typed)
182193

183194
config_kwargs = {}
184195
if optimization_level:

0 commit comments

Comments
 (0)