Skip to content

Commit 52cc3d2

Browse files
Fixes KeyError in precision flags helper function
1 parent cfe3b2f commit 52cc3d2

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

model_navigator/commands/convert/converters/onnx2trt.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,32 @@
3232
def _get_precisions(precision, precision_mode):
    """Map a (precision, precision_mode) pair to the (tf32, fp16, bf16) builder flags.

    Args:
        precision: Value convertible to ``TensorRTPrecision`` (e.g. "fp16").
        precision_mode: Value convertible to ``TensorRTPrecisionMode``.

    Returns:
        Tuple of three booleans ``(tf32, fp16, bf16)`` to pass as TensorRT
        builder-config flags.

    Raises:
        ValueError: If ``precision_mode`` is neither HIERARCHY nor SINGLE.
    """
    precision = TensorRTPrecision(precision)
    precision_mode = TensorRTPrecisionMode(precision_mode)

    # Quantized formats carry their own calibration path and enable none of
    # the floating-point builder flags.
    if precision in (TensorRTPrecision.INT8, TensorRTPrecision.FP8, TensorRTPrecision.NVFP4):
        return False, False, False

    # HIERARCHY enables every precision at or above the requested one;
    # SINGLE enables only the requested precision.
    flags_by_mode = {
        TensorRTPrecisionMode.HIERARCHY: {
            TensorRTPrecision.FP32: (True, False, False),
            TensorRTPrecision.FP16: (True, True, False),
            TensorRTPrecision.BF16: (True, True, True),
        },
        TensorRTPrecisionMode.SINGLE: {
            TensorRTPrecision.FP32: (True, False, False),
            TensorRTPrecision.FP16: (False, True, False),
            TensorRTPrecision.BF16: (False, False, True),
        },
    }

    if precision_mode not in flags_by_mode:
        raise ValueError(
            f"Unsupported precision mode {precision_mode}. Only {TensorRTPrecisionMode.HIERARCHY} and "
            f"{TensorRTPrecisionMode.SINGLE} are allowed"
        )

    return flags_by_mode[precision_mode][precision]
5361

5462

5563
def _quantize_model(
@@ -91,7 +99,6 @@ def _build_create_config_kwargs(
9199
"load_timing_cache": timing_cache,
92100
**custom_args,
93101
}
94-
tf32, fp16, bf16 = _get_precisions(precision, precision_mode)
95102

96103
if optimization_level:
97104
create_config_kwargs["builder_optimization_level"] = optimization_level
@@ -105,9 +112,8 @@ def _build_create_config_kwargs(
105112

106113
# Set precision-specific flags
107114
if TensorRTPrecision(precision) not in (TensorRTPrecision.INT8, TensorRTPrecision.FP8, TensorRTPrecision.NVFP4):
108-
create_config_kwargs["tf32"] = tf32
109-
create_config_kwargs["fp16"] = fp16
110-
create_config_kwargs["bf16"] = bf16
115+
tf32, fp16, bf16 = _get_precisions(precision, precision_mode)
116+
create_config_kwargs.update({"tf32": tf32, "fp16": fp16, "bf16": bf16})
111117
return create_config_kwargs
112118

113119

0 commit comments

Comments
 (0)