def _get_precisions(precision, precision_mode):
    """Resolve the TF32/FP16/BF16 flags for a precision and precision mode.

    Args:
        precision: Value convertible to ``TensorRTPrecision``.
        precision_mode: Value convertible to ``TensorRTPrecisionMode``.

    Returns:
        A ``(tf32, fp16, bf16)`` tuple of booleans. For the quantized
        precisions (INT8, FP8, NVFP4) all three flags are ``False``.

    Raises:
        ValueError: If ``precision_mode`` is neither HIERARCHY nor SINGLE
            (only checked for non-quantized precisions).
    """
    precision = TensorRTPrecision(precision)
    precision_mode = TensorRTPrecisionMode(precision_mode)

    # Quantized precisions short-circuit before the mode lookup: none of the
    # floating-point flags is enabled for them, whatever the mode is.
    if precision in (TensorRTPrecision.INT8, TensorRTPrecision.FP8, TensorRTPrecision.NVFP4):
        return False, False, False

    # Flag tuples are ordered (tf32, fp16, bf16).
    # HIERARCHY also enables the flags of the precisions listed before the
    # requested one; SINGLE enables only the requested precision's flag.
    precision_configs = {
        TensorRTPrecisionMode.HIERARCHY: {
            TensorRTPrecision.FP32: (True, False, False),
            TensorRTPrecision.FP16: (True, True, False),
            TensorRTPrecision.BF16: (True, True, True),
        },
        TensorRTPrecisionMode.SINGLE: {
            TensorRTPrecision.FP32: (True, False, False),
            TensorRTPrecision.FP16: (False, True, False),
            TensorRTPrecision.BF16: (False, False, True),
        },
    }

    if precision_mode not in precision_configs:
        raise ValueError(
            f"Unsupported precision mode {precision_mode}. Only {TensorRTPrecisionMode.HIERARCHY} and "
            f"{TensorRTPrecisionMode.SINGLE} are allowed"
        )

    return precision_configs[precision_mode][precision]
5361
5462
5563def _quantize_model (
@@ -91,7 +99,6 @@ def _build_create_config_kwargs(
9199 "load_timing_cache" : timing_cache ,
92100 ** custom_args ,
93101 }
94- tf32 , fp16 , bf16 = _get_precisions (precision , precision_mode )
95102
96103 if optimization_level :
97104 create_config_kwargs ["builder_optimization_level" ] = optimization_level
@@ -105,9 +112,8 @@ def _build_create_config_kwargs(
105112
106113 # Set precision-specific flags
107114 if TensorRTPrecision (precision ) not in (TensorRTPrecision .INT8 , TensorRTPrecision .FP8 , TensorRTPrecision .NVFP4 ):
108- create_config_kwargs ["tf32" ] = tf32
109- create_config_kwargs ["fp16" ] = fp16
110- create_config_kwargs ["bf16" ] = bf16
115+ tf32 , fp16 , bf16 = _get_precisions (precision , precision_mode )
116+ create_config_kwargs .update ({"tf32" : tf32 , "fp16" : fp16 , "bf16" : bf16 })
111117 return create_config_kwargs
112118
113119
0 commit comments