
Commit 712fb04

Disable precision flags for quantized ONNX to TensorRT conversion
1 parent ff6065d commit 712fb04

File tree

1 file changed: +59 -44 lines changed

model_navigator/commands/convert/converters/onnx2trt.py

Lines changed: 59 additions & 44 deletions
@@ -33,33 +33,23 @@ def _get_precisions(precision, precision_mode):
     precision = TensorRTPrecision(precision)
     precision_mode = TensorRTPrecisionMode(precision_mode)
     if precision_mode == TensorRTPrecisionMode.HIERARCHY:
-        tf32, fp16, bf16, fp8, int8, nvfp4 = {
-            # TODO: Enable hierarchical BF16 for FP16, FP8 and INT8 after it's supported
-            TensorRTPrecision.FP32: [True, False, False, False, False, False],
-            # TensorRTPrecision.FP16: [True, True, True, False, False],
-            TensorRTPrecision.FP16: [True, True, False, False, False, False],
-            TensorRTPrecision.BF16: [True, True, True, False, False, False],
-            # TensorRTPrecision.FP8: [True, True, True, True, False],
-            TensorRTPrecision.FP8: [True, True, False, True, False, False],
-            # TensorRTPrecision.INT8: [True, True, True, False, True],
-            TensorRTPrecision.INT8: [True, True, False, False, True, False],
-            TensorRTPrecision.NVFP4: [True, True, False, False, False, True],
+        tf32, fp16, bf16 = {
+            TensorRTPrecision.FP32: [True, False, False],
+            TensorRTPrecision.FP16: [True, True, False],
+            TensorRTPrecision.BF16: [True, True, True],
         }[precision]
     elif precision_mode == TensorRTPrecisionMode.SINGLE:
-        tf32, fp16, bf16, fp8, int8, nvfp4 = {
-            TensorRTPrecision.FP32: [True, False, False, False, False, False],
-            TensorRTPrecision.FP16: [False, True, False, False, False, False],
-            TensorRTPrecision.BF16: [False, False, True, False, False, False],
-            TensorRTPrecision.FP8: [False, False, False, True, False, False],
-            TensorRTPrecision.INT8: [False, False, False, False, True, False],
-            TensorRTPrecision.NVFP4: [False, False, False, False, False, True],
+        tf32, fp16, bf16 = {
+            TensorRTPrecision.FP32: [True, False, False],
+            TensorRTPrecision.FP16: [False, True, False],
+            TensorRTPrecision.BF16: [False, False, True],
         }[precision]
     else:
         raise ValueError(
             f"Unsupported precision mode {precision_mode}. Only {TensorRTPrecisionMode.HIERARCHY} and "
             f"{TensorRTPrecisionMode.SINGLE} are allowed"
         )
-    return tf32, fp16, bf16, fp8, int8, nvfp4
+    return tf32, fp16, bf16


 def _quantize_model(
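In short, _get_precisions now covers only the non-quantized precisions and returns a three-flag tuple. A minimal standalone sketch of the trimmed mapping (the Precision enum below is a local stand-in for illustration, not model_navigator's TensorRTPrecision):

from enum import Enum

class Precision(Enum):  # local stand-in for TensorRTPrecision
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"

def get_precisions(precision, hierarchy=True):
    # HIERARCHY mode enables every precision up to the requested one;
    # SINGLE mode (hierarchy=False) enables only the requested precision.
    return {
        Precision.FP32: (True, False, False),
        Precision.FP16: (True, True, False) if hierarchy else (False, True, False),
        Precision.BF16: (True, True, True) if hierarchy else (False, False, True),
    }[precision]

assert get_precisions(Precision.FP16) == (True, True, False)
assert get_precisions(Precision.BF16, hierarchy=False) == (False, False, True)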
@@ -86,6 +76,41 @@ def _quantize_model(
     LOGGER.info("Quantized ONNX model saved in {}", quantized_onnx_path)


+def _build_create_config_kwargs(
+    max_workspace_size,
+    precision,
+    precision_mode,
+    optimization_level,
+    compatibility_level,
+    custom_args,
+    trt_profiles,
+    timing_cache,
+):
+    create_config_kwargs = {
+        "profiles": trt_profiles,
+        "load_timing_cache": timing_cache,
+        **custom_args,
+    }
+    tf32, fp16, bf16 = _get_precisions(precision, precision_mode)
+
+    if optimization_level:
+        create_config_kwargs["builder_optimization_level"] = optimization_level
+    if compatibility_level:
+        create_config_kwargs["hardware_compatibility_level"] = compatibility_level
+
+    if max_workspace_size:
+        create_config_kwargs["memory_pool_limits"] = {
+            trt.MemoryPoolType.WORKSPACE: max_workspace_size,
+        }
+
+    # Set precision-specific flags
+    if TensorRTPrecision(precision) not in (TensorRTPrecision.INT8, TensorRTPrecision.FP8, TensorRTPrecision.NVFP4):
+        create_config_kwargs["tf32"] = tf32
+        create_config_kwargs["fp16"] = fp16
+        create_config_kwargs["bf16"] = bf16
+    return create_config_kwargs
+
+
 def convert(
     exported_model_path: str,
     converted_model_path: str,
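The key behavior of the new helper: quantized precisions (INT8, FP8, NVFP4) no longer receive tf32/fp16/bf16 flags at all. A standalone sketch of just that gating, with assumed names and string spellings for illustration:

QUANTIZED = {"int8", "fp8", "nvfp4"}  # assumed spellings, for illustration only

def build_config_kwargs(precision, tf32, fp16, bf16, **custom_args):
    kwargs = dict(custom_args)
    # A quantized ONNX model is parsed strongly typed, so the builder takes its
    # types from the model itself; explicit precision flags are then omitted.
    if precision not in QUANTIZED:
        kwargs.update(tf32=tf32, fp16=fp16, bf16=bf16)
    return kwargs

assert "fp16" in build_config_kwargs("fp16", True, True, False)
assert "fp16" not in build_config_kwargs("int8", True, True, False)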
@@ -160,8 +185,6 @@ def convert(
     if not trt_profiles:
         trt_profiles = [Profile()]

-    # nvfp4 is currently not used as flag for converter, skip it
-    tf32, fp16, bf16, fp8, int8, _ = _get_precisions(precision, precision_mode)
     strongly_typed = False

     # Determine the path to use for ONNX model
@@ -186,21 +209,12 @@
         onnx_path = pathlib.Path(quantized_onnx_path)
     # For NVFP4, always use the quantized path (even if not quantized yet)
     elif quantized_onnx_path and TensorRTPrecision(precision) == TensorRTPrecision.NVFP4:
-        strongly_typed = True
         onnx_path = pathlib.Path(quantized_onnx_path)

-    network = network_from_onnx_path(onnx_path.as_posix(), flags=onnx_parser_flags, strongly_typed=strongly_typed)
-
-    config_kwargs = {}
-    if optimization_level:
-        config_kwargs["builder_optimization_level"] = optimization_level
-    if compatibility_level:
-        config_kwargs["hardware_compatibility_level"] = compatibility_level
+    if TensorRTPrecision(precision) in (TensorRTPrecision.INT8, TensorRTPrecision.FP8, TensorRTPrecision.NVFP4):
+        strongly_typed = True

-    if max_workspace_size:
-        config_kwargs["memory_pool_limits"] = {
-            trt.MemoryPoolType.WORKSPACE: max_workspace_size,
-        }
+    network = network_from_onnx_path(onnx_path.as_posix(), flags=onnx_parser_flags, strongly_typed=strongly_typed)

     # saving timing cache in model_navigator workspace or ...
     timing_cache = trt_cache_inplace_cache_dir()
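strongly_typed is now set for every quantized precision before the ONNX parse, not only for NVFP4. A hedged sketch of the equivalent Polygraphy call (placeholder path and precision value; my understanding is that Polygraphy forwards strongly_typed to TensorRT's strongly typed network creation flag):

from polygraphy.backend.trt import network_from_onnx_path

precision = "int8"  # placeholder value for illustration
strongly_typed = precision in ("int8", "fp8", "nvfp4")
# Returns a (builder, network, parser) tuple; with strongly_typed=True the
# parser keeps the type annotations baked into the quantized model.
builder, network, parser = network_from_onnx_path(
    "model.quantized.onnx", strongly_typed=strongly_typed
)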
@@ -210,19 +224,20 @@
     with TimingCacheManager(model_name=model_name, cache_path=timing_cache) as timing_cache:
         timing_cache = timing_cache.as_posix() if timing_cache else None

+        create_config_kwargs = _build_create_config_kwargs(
+            max_workspace_size,
+            precision,
+            precision_mode,
+            optimization_level,
+            compatibility_level,
+            custom_args,
+            trt_profiles,
+            timing_cache,
+        )
+
         engine = engine_from_network(
             network,
-            config=CreateConfig(
-                tf32=tf32,
-                fp16=fp16,
-                bf16=bf16,
-                fp8=fp8,
-                int8=int8,
-                profiles=trt_profiles,
-                load_timing_cache=timing_cache,
-                **config_kwargs,
-                **custom_args,
-            ),
+            config=CreateConfig(**create_config_kwargs),
             save_timing_cache=timing_cache,
         )
         save_engine(engine, path=converted_model_path)
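End to end, the build now funnels all options through a single kwargs dict into CreateConfig. A condensed sketch of the resulting flow with Polygraphy (paths and flag values are placeholder assumptions, not the repository's actual configuration):

from polygraphy.backend.trt import (
    CreateConfig,
    engine_from_network,
    network_from_onnx_path,
    save_engine,
)

network = network_from_onnx_path("model.onnx", strongly_typed=False)
# For a non-quantized fp16 build the helper would yield flags like these;
# a quantized (int8/fp8/nvfp4) build would omit tf32/fp16/bf16 entirely.
create_config_kwargs = {"tf32": True, "fp16": True, "bf16": False}
engine = engine_from_network(network, config=CreateConfig(**create_config_kwargs))
save_engine(engine, path="model.plan")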
