@@ -115,6 +115,7 @@ def __init__(
115115 augmented_model_path = "augmented_model.onnx" ,
116116 symmetric = False ,
117117 use_external_data_format = False ,
118+ data_types_to_calibrate : list [TensorProto .DataType ] = [TensorProto .FLOAT ],
118119 ):
119120 """
120121 :param model_path: ONNX model to calibrate. It should be a model file path
@@ -138,6 +139,7 @@ def __init__(
138139 self .augment_model = None
139140 self .infer_session = None
140141 self .execution_providers = ["CPUExecutionProvider" ]
142+ self .tensor_types_to_calibrate = data_types_to_calibrate
141143
142144 def set_execution_providers (self , execution_providers = ["CPUExecutionProvider" ]): # noqa: B006
143145 """
@@ -171,7 +173,6 @@ def select_tensors_to_calibrate(self, model: ModelProto):
171173 initializer = {init .name for init in model .graph .initializer }
172174
173175 tensors_to_calibrate = set ()
174- tensor_type_to_calibrate = {TensorProto .FLOAT }
175176
176177 for node in model .graph .node :
177178 if not self .op_types_to_calibrate or node .op_type in self .op_types_to_calibrate :
@@ -180,7 +181,7 @@ def select_tensors_to_calibrate(self, model: ModelProto):
180181 vi = value_infos [tensor_name ]
181182 if (
182183 vi .type .HasField ("tensor_type" )
183- and (vi .type .tensor_type .elem_type in tensor_type_to_calibrate )
184+ and (vi .type .tensor_type .elem_type in self . tensor_types_to_calibrate )
184185 and (tensor_name not in initializer )
185186 ):
186187 tensors_to_calibrate .add (tensor_name )
@@ -224,6 +225,7 @@ def __init__(
224225 use_external_data_format = False ,
225226 moving_average = False ,
226227 averaging_constant = 0.01 ,
228+ data_types_to_calibrate : list [TensorProto .DataType ] = [TensorProto .FLOAT ],
227229 ):
228230 """
229231 :param model_path: ONNX model to calibrate. It is a model path
@@ -240,6 +242,7 @@ def __init__(
240242 augmented_model_path = augmented_model_path ,
241243 symmetric = symmetric ,
242244 use_external_data_format = use_external_data_format ,
245+ data_types_to_calibrate = data_types_to_calibrate ,
243246 )
244247 self .intermediate_outputs = []
245248 self .calibrate_tensors_range = None
@@ -256,7 +259,7 @@ def augment_graph(self):
256259 model and ensures their outputs are stored as part of the graph output
257260 :return: augmented ONNX model
258261 """
259- tensors , _ = self .select_tensors_to_calibrate (self .model )
262+ tensors , value_infos = self .select_tensors_to_calibrate (self .model )
260263 reshape_shape_name = str (uuid .uuid4 ())
261264 reshape_shape = numpy_helper .from_array (np .array ([1 ], dtype = np .int64 ), reshape_shape_name )
262265 self .model .graph .initializer .append (reshape_shape )
@@ -280,8 +283,10 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
280283 name = intermediate_output ,
281284 )
282285
286+ out_dtype = value_infos [tensor ].type .tensor_type .elem_type
287+
283288 self .model .graph .node .extend ([reduce_node , reshape_node ])
284- self .model .graph .output .append (helper .make_tensor_value_info (reduce_output , TensorProto . FLOAT , [1 ]))
289+ self .model .graph .output .append (helper .make_tensor_value_info (reduce_output , out_dtype , [1 ]))
285290
286291 for tensor in tensors :
287292 add_reduce_min_max (tensor , "ReduceMin" )
@@ -396,6 +401,7 @@ def __init__(
396401 num_quantized_bins = 2048 ,
397402 percentile = 99.999 ,
398403 scenario = "same" ,
404+ data_types_to_calibrate : list [TensorProto .DataType ] = [TensorProto .FLOAT ]
399405 ):
400406 """
401407 :param model_path: ONNX model to calibrate. It is a model path.
@@ -415,6 +421,7 @@ def __init__(
415421 augmented_model_path = augmented_model_path ,
416422 symmetric = symmetric ,
417423 use_external_data_format = use_external_data_format ,
424+ data_types_to_calibrate = data_types_to_calibrate ,
418425 )
419426 self .intermediate_outputs = []
420427 self .calibrate_tensors_range = None
@@ -515,6 +522,7 @@ def __init__(
515522 symmetric = False ,
516523 num_bins = 128 ,
517524 num_quantized_bins = 128 ,
525+ data_types_to_calibrate : list [TensorProto ] = [TensorProto .FLOAT ],
518526 ):
519527 """
520528 :param model_path: ONNX model to calibrate. It is a model path
@@ -535,6 +543,7 @@ def __init__(
535543 symmetric = symmetric ,
536544 num_bins = num_bins ,
537545 num_quantized_bins = num_quantized_bins ,
546+ data_types_to_calibrate = data_types_to_calibrate ,
538547 )
539548
540549
@@ -549,6 +558,7 @@ def __init__(
549558 symmetric = False ,
550559 num_bins = 2048 ,
551560 percentile = 99.999 ,
561+ data_types_to_calibrate : list [TensorProto ] = [TensorProto .FLOAT ],
552562 ):
553563 """
554564 :param model_path: ONNX model to calibrate. It is a model path
@@ -569,6 +579,7 @@ def __init__(
569579 symmetric = symmetric ,
570580 num_bins = num_bins ,
571581 percentile = percentile ,
582+ data_types_to_calibrate = data_types_to_calibrate ,
572583 )
573584
574585
@@ -582,6 +593,7 @@ def __init__(
582593 method = "distribution" ,
583594 num_bins = 128 ,
584595 scenario = "same" ,
596+ data_types_to_calibrate : list [TensorProto ] = [TensorProto .FLOAT ],
585597 ):
586598 """
587599 :param model_path: ONNX model to calibrate. It is a model path
@@ -604,6 +616,7 @@ def __init__(
604616 method = method ,
605617 num_bins = num_bins ,
606618 scenario = scenario ,
619+ data_types_to_calibrate = data_types_to_calibrate ,
607620 )
608621
609622
@@ -1004,6 +1017,7 @@ def create_calibrator(
10041017 calibrate_method = CalibrationMethod .MinMax ,
10051018 use_external_data_format = False ,
10061019 extra_options = {}, # noqa: B006
1020+ data_types_to_calibrate : list [TensorProto .DataType ] = [TensorProto .FLOAT ],
10071021):
10081022 calibrator = None
10091023 if calibrate_method == CalibrationMethod .MinMax :
@@ -1019,6 +1033,7 @@ def create_calibrator(
10191033 symmetric = symmetric ,
10201034 moving_average = moving_average ,
10211035 averaging_constant = averaging_constant ,
1036+ data_types_to_calibrate = data_types_to_calibrate ,
10221037 )
10231038 elif calibrate_method == CalibrationMethod .Entropy :
10241039 # default settings for entropy algorithm
@@ -1033,6 +1048,7 @@ def create_calibrator(
10331048 symmetric = symmetric ,
10341049 num_bins = num_bins ,
10351050 num_quantized_bins = num_quantized_bins ,
1051+ data_types_to_calibrate = data_types_to_calibrate ,
10361052 )
10371053 elif calibrate_method == CalibrationMethod .Percentile :
10381054 # default settings for percentile algorithm
@@ -1047,6 +1063,7 @@ def create_calibrator(
10471063 symmetric = symmetric ,
10481064 num_bins = num_bins ,
10491065 percentile = percentile ,
1066+ data_types_to_calibrate = data_types_to_calibrate ,
10501067 )
10511068
10521069 elif calibrate_method == CalibrationMethod .Distribution :
@@ -1061,6 +1078,7 @@ def create_calibrator(
10611078 use_external_data_format = use_external_data_format ,
10621079 num_bins = num_bins ,
10631080 scenario = scenario ,
1081+ data_types_to_calibrate = data_types_to_calibrate ,
10641082 )
10651083
10661084 if calibrator :
0 commit comments