@@ -33,33 +33,23 @@ def _get_precisions(precision, precision_mode):
     precision = TensorRTPrecision(precision)
     precision_mode = TensorRTPrecisionMode(precision_mode)
     if precision_mode == TensorRTPrecisionMode.HIERARCHY:
-        tf32, fp16, bf16, fp8, int8, nvfp4 = {
-            # TODO: Enable hierarchical BF16 for FP16, FP8 and INT8 after it's supported
-            TensorRTPrecision.FP32: [True, False, False, False, False, False],
-            # TensorRTPrecision.FP16: [True, True, True, False, False],
-            TensorRTPrecision.FP16: [True, True, False, False, False, False],
-            TensorRTPrecision.BF16: [True, True, True, False, False, False],
-            # TensorRTPrecision.FP8: [True, True, True, True, False],
-            TensorRTPrecision.FP8: [True, True, False, True, False, False],
-            # TensorRTPrecision.INT8: [True, True, True, False, True],
-            TensorRTPrecision.INT8: [True, True, False, False, True, False],
-            TensorRTPrecision.NVFP4: [True, True, False, False, False, True],
+        tf32, fp16, bf16 = {
+            TensorRTPrecision.FP32: [True, False, False],
+            TensorRTPrecision.FP16: [True, True, False],
+            TensorRTPrecision.BF16: [True, True, True],
         }[precision]
     elif precision_mode == TensorRTPrecisionMode.SINGLE:
-        tf32, fp16, bf16, fp8, int8, nvfp4 = {
-            TensorRTPrecision.FP32: [True, False, False, False, False, False],
-            TensorRTPrecision.FP16: [False, True, False, False, False, False],
-            TensorRTPrecision.BF16: [False, False, True, False, False, False],
-            TensorRTPrecision.FP8: [False, False, False, True, False, False],
-            TensorRTPrecision.INT8: [False, False, False, False, True, False],
-            TensorRTPrecision.NVFP4: [False, False, False, False, False, True],
+        tf32, fp16, bf16 = {
+            TensorRTPrecision.FP32: [True, False, False],
+            TensorRTPrecision.FP16: [False, True, False],
+            TensorRTPrecision.BF16: [False, False, True],
         }[precision]
     else:
         raise ValueError(
             f"Unsupported precision mode {precision_mode}. Only {TensorRTPrecisionMode.HIERARCHY} and "
             f"{TensorRTPrecisionMode.SINGLE} are allowed"
         )
-    return tf32, fp16, bf16, fp8, int8, nvfp4
+    return tf32, fp16, bf16


 def _quantize_model(
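For reference, the trimmed helper now returns only the weakly typed builder-flag triple. A minimal sketch of the two modes, assuming the module's existing enum imports (illustrative, not part of the commit); note that the quantized precisions (INT8, FP8, NVFP4) are gone from both tables, so callers must guard the lookup, as `_build_create_config_kwargs` below does:

```python
# Illustrative only; TensorRTPrecision/TensorRTPrecisionMode are already in scope in this module.

# HIERARCHY keeps every higher precision enabled as a fallback:
assert _get_precisions(TensorRTPrecision.BF16, TensorRTPrecisionMode.HIERARCHY) == (True, True, True)

# SINGLE enables exactly the requested precision:
assert _get_precisions(TensorRTPrecision.FP16, TensorRTPrecisionMode.SINGLE) == (False, True, False)

# Quantized precisions no longer appear in the tables:
# _get_precisions(TensorRTPrecision.FP8, TensorRTPrecisionMode.SINGLE)  # -> KeyError
```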
@@ -86,6 +76,41 @@ def _quantize_model(
     LOGGER.info("Quantized ONNX model saved in {}", quantized_onnx_path)


+def _build_create_config_kwargs(
+    max_workspace_size,
+    precision,
+    precision_mode,
+    optimization_level,
+    compatibility_level,
+    custom_args,
+    trt_profiles,
+    timing_cache,
+):
+    create_config_kwargs = {
+        "profiles": trt_profiles,
+        "load_timing_cache": timing_cache,
+        **custom_args,
+    }
+
+    if optimization_level:
+        create_config_kwargs["builder_optimization_level"] = optimization_level
+    if compatibility_level:
+        create_config_kwargs["hardware_compatibility_level"] = compatibility_level
+
+    if max_workspace_size:
+        create_config_kwargs["memory_pool_limits"] = {
+            trt.MemoryPoolType.WORKSPACE: max_workspace_size,
+        }
+
+    # Set precision-specific flags only for weakly typed builds; the quantized
+    # precisions (INT8/FP8/NVFP4) build strongly typed engines and were removed
+    # from _get_precisions, so the lookup must be guarded to avoid a KeyError.
+    if TensorRTPrecision(precision) not in (TensorRTPrecision.INT8, TensorRTPrecision.FP8, TensorRTPrecision.NVFP4):
+        tf32, fp16, bf16 = _get_precisions(precision, precision_mode)
+        create_config_kwargs["tf32"] = tf32
+        create_config_kwargs["fp16"] = fp16
+        create_config_kwargs["bf16"] = bf16
+    return create_config_kwargs
+
+
 def convert(
     exported_model_path: str,
     converted_model_path: str,
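A quick sanity check of the helper (illustrative values; `Profile` and the enums come from the module's existing imports). Because TensorRT derives per-layer types from the quantized ONNX graph in strongly typed mode, the weakly typed `tf32`/`fp16`/`bf16` flags are emitted only for the non-quantized precisions:

```python
common = dict(
    max_workspace_size=8 * 2**30,  # cap the WORKSPACE memory pool at 8 GiB
    optimization_level=None,
    compatibility_level=None,
    custom_args={},
    trt_profiles=[Profile()],
    timing_cache=None,
)

# FP16 (weakly typed): precision flags are included for CreateConfig.
fp16_kwargs = _build_create_config_kwargs(
    precision=TensorRTPrecision.FP16, precision_mode=TensorRTPrecisionMode.HIERARCHY, **common
)
assert fp16_kwargs["tf32"] and fp16_kwargs["fp16"] and not fp16_kwargs["bf16"]

# FP8 (strongly typed): no precision flags at all.
fp8_kwargs = _build_create_config_kwargs(
    precision=TensorRTPrecision.FP8, precision_mode=TensorRTPrecisionMode.HIERARCHY, **common
)
assert "fp16" not in fp8_kwargs and "memory_pool_limits" in fp8_kwargs
```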
@@ -160,8 +185,6 @@ def convert(
     if not trt_profiles:
         trt_profiles = [Profile()]

-    # nvfp4 is currently not used as flag for converter, skip it
-    tf32, fp16, bf16, fp8, int8, _ = _get_precisions(precision, precision_mode)
     strongly_typed = False

     # Determine the path to use for ONNX model
@@ -186,21 +209,12 @@ def convert(
         onnx_path = pathlib.Path(quantized_onnx_path)
     # For NVFP4, always use the quantized path (even if not quantized yet)
     elif quantized_onnx_path and TensorRTPrecision(precision) == TensorRTPrecision.NVFP4:
-        strongly_typed = True
         onnx_path = pathlib.Path(quantized_onnx_path)

-    network = network_from_onnx_path(onnx_path.as_posix(), flags=onnx_parser_flags, strongly_typed=strongly_typed)
-
-    config_kwargs = {}
-    if optimization_level:
-        config_kwargs["builder_optimization_level"] = optimization_level
-    if compatibility_level:
-        config_kwargs["hardware_compatibility_level"] = compatibility_level
+    if TensorRTPrecision(precision) in (TensorRTPrecision.INT8, TensorRTPrecision.FP8, TensorRTPrecision.NVFP4):
+        strongly_typed = True

-    if max_workspace_size:
-        config_kwargs["memory_pool_limits"] = {
-            trt.MemoryPoolType.WORKSPACE: max_workspace_size,
-        }
+    network = network_from_onnx_path(onnx_path.as_posix(), flags=onnx_parser_flags, strongly_typed=strongly_typed)

     # saving timing cache in model_navigator workspace or ...
     timing_cache = trt_cache_inplace_cache_dir()
@@ -210,19 +224,20 @@
     with TimingCacheManager(model_name=model_name, cache_path=timing_cache) as timing_cache:
         timing_cache = timing_cache.as_posix() if timing_cache else None

+        create_config_kwargs = _build_create_config_kwargs(
+            max_workspace_size,
+            precision,
+            precision_mode,
+            optimization_level,
+            compatibility_level,
+            custom_args,
+            trt_profiles,
+            timing_cache,
+        )
+
         engine = engine_from_network(
             network,
-            config=CreateConfig(
-                tf32=tf32,
-                fp16=fp16,
-                bf16=bf16,
-                fp8=fp8,
-                int8=int8,
-                profiles=trt_profiles,
-                load_timing_cache=timing_cache,
-                **config_kwargs,
-                **custom_args,
-            ),
+            config=CreateConfig(**create_config_kwargs),
             save_timing_cache=timing_cache,
         )
         save_engine(engine, path=converted_model_path)
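End to end, the conversion now reduces to the following shape (a sketch only; `model.onnx` and `model.plan` are placeholder paths, parser flags and the timing cache are omitted, and the polygraphy imports match those already used by this module):

```python
from polygraphy.backend.trt import (
    CreateConfig,
    Profile,
    engine_from_network,
    network_from_onnx_path,
    save_engine,
)

# Weakly typed FP16 build; for INT8/FP8/NVFP4, strongly_typed would be True
# and the precision flags would be absent from the kwargs.
network = network_from_onnx_path("model.onnx", strongly_typed=False)
create_config_kwargs = _build_create_config_kwargs(
    max_workspace_size=None,
    precision=TensorRTPrecision.FP16,
    precision_mode=TensorRTPrecisionMode.HIERARCHY,
    optimization_level=3,
    compatibility_level=None,
    custom_args={},
    trt_profiles=[Profile()],
    timing_cache=None,
)
engine = engine_from_network(network, config=CreateConfig(**create_config_kwargs))
save_engine(engine, path="model.plan")
```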