Skip to content

Commit 190e428

Browse files
committed
Add calibrator option to add extra dtypes
1 parent 7845828 commit 190e428

File tree

1 file changed: 22 additions and 4 deletions

onnxruntime/python/tools/quantization/calibrate.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ def __init__(
115115
augmented_model_path="augmented_model.onnx",
116116
symmetric=False,
117117
use_external_data_format=False,
118+
data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT],
118119
):
119120
"""
120121
:param model_path: ONNX model to calibrate. It should be a model file path
@@ -138,6 +139,7 @@ def __init__(
138139
self.augment_model = None
139140
self.infer_session = None
140141
self.execution_providers = ["CPUExecutionProvider"]
142+
self.tensor_types_to_calibrate = data_types_to_calibrate
141143

142144
def set_execution_providers(self, execution_providers=["CPUExecutionProvider"]): # noqa: B006
143145
"""
@@ -171,7 +173,6 @@ def select_tensors_to_calibrate(self, model: ModelProto):
171173
initializer = {init.name for init in model.graph.initializer}
172174

173175
tensors_to_calibrate = set()
174-
tensor_type_to_calibrate = {TensorProto.FLOAT}
175176

176177
for node in model.graph.node:
177178
if not self.op_types_to_calibrate or node.op_type in self.op_types_to_calibrate:
@@ -180,7 +181,7 @@ def select_tensors_to_calibrate(self, model: ModelProto):
180181
vi = value_infos[tensor_name]
181182
if (
182183
vi.type.HasField("tensor_type")
183-
and (vi.type.tensor_type.elem_type in tensor_type_to_calibrate)
184+
and (vi.type.tensor_type.elem_type in self.tensor_types_to_calibrate)
184185
and (tensor_name not in initializer)
185186
):
186187
tensors_to_calibrate.add(tensor_name)
@@ -224,6 +225,7 @@ def __init__(
224225
use_external_data_format=False,
225226
moving_average=False,
226227
averaging_constant=0.01,
228+
data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT],
227229
):
228230
"""
229231
:param model_path: ONNX model to calibrate. It is a model path
@@ -240,6 +242,7 @@ def __init__(
240242
augmented_model_path=augmented_model_path,
241243
symmetric=symmetric,
242244
use_external_data_format=use_external_data_format,
245+
data_types_to_calibrate=data_types_to_calibrate,
243246
)
244247
self.intermediate_outputs = []
245248
self.calibrate_tensors_range = None
@@ -256,7 +259,7 @@ def augment_graph(self):
256259
model and ensures their outputs are stored as part of the graph output
257260
:return: augmented ONNX model
258261
"""
259-
tensors, _ = self.select_tensors_to_calibrate(self.model)
262+
tensors, value_infos = self.select_tensors_to_calibrate(self.model)
260263
reshape_shape_name = str(uuid.uuid4())
261264
reshape_shape = numpy_helper.from_array(np.array([1], dtype=np.int64), reshape_shape_name)
262265
self.model.graph.initializer.append(reshape_shape)
@@ -280,8 +283,10 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
280283
name=intermediate_output,
281284
)
282285

286+
out_dtype = value_infos[tensor].type.tensor_type.elem_type
287+
283288
self.model.graph.node.extend([reduce_node, reshape_node])
284-
self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, TensorProto.FLOAT, [1]))
289+
self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, out_dtype, [1]))
285290

286291
for tensor in tensors:
287292
add_reduce_min_max(tensor, "ReduceMin")
@@ -396,6 +401,7 @@ def __init__(
396401
num_quantized_bins=2048,
397402
percentile=99.999,
398403
scenario="same",
404+
data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT]
399405
):
400406
"""
401407
:param model_path: ONNX model to calibrate. It is a model path.
@@ -415,6 +421,7 @@ def __init__(
415421
augmented_model_path=augmented_model_path,
416422
symmetric=symmetric,
417423
use_external_data_format=use_external_data_format,
424+
data_types_to_calibrate=data_types_to_calibrate,
418425
)
419426
self.intermediate_outputs = []
420427
self.calibrate_tensors_range = None
@@ -515,6 +522,7 @@ def __init__(
515522
symmetric=False,
516523
num_bins=128,
517524
num_quantized_bins=128,
525+
data_types_to_calibrate: list[TensorProto] = [TensorProto.FLOAT],
518526
):
519527
"""
520528
:param model_path: ONNX model to calibrate. It is a model path
@@ -535,6 +543,7 @@ def __init__(
535543
symmetric=symmetric,
536544
num_bins=num_bins,
537545
num_quantized_bins=num_quantized_bins,
546+
data_types_to_calibrate=data_types_to_calibrate,
538547
)
539548

540549

@@ -549,6 +558,7 @@ def __init__(
549558
symmetric=False,
550559
num_bins=2048,
551560
percentile=99.999,
561+
data_types_to_calibrate: list[TensorProto] = [TensorProto.FLOAT],
552562
):
553563
"""
554564
:param model_path: ONNX model to calibrate. It is a model path
@@ -569,6 +579,7 @@ def __init__(
569579
symmetric=symmetric,
570580
num_bins=num_bins,
571581
percentile=percentile,
582+
data_types_to_calibrate=data_types_to_calibrate,
572583
)
573584

574585

@@ -582,6 +593,7 @@ def __init__(
582593
method="distribution",
583594
num_bins=128,
584595
scenario="same",
596+
data_types_to_calibrate: list[TensorProto] = [TensorProto.FLOAT],
585597
):
586598
"""
587599
:param model_path: ONNX model to calibrate. It is a model path
@@ -604,6 +616,7 @@ def __init__(
604616
method=method,
605617
num_bins=num_bins,
606618
scenario=scenario,
619+
data_types_to_calibrate=data_types_to_calibrate,
607620
)
608621

609622

@@ -1004,6 +1017,7 @@ def create_calibrator(
10041017
calibrate_method=CalibrationMethod.MinMax,
10051018
use_external_data_format=False,
10061019
extra_options={}, # noqa: B006
1020+
data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT],
10071021
):
10081022
calibrator = None
10091023
if calibrate_method == CalibrationMethod.MinMax:
@@ -1019,6 +1033,7 @@ def create_calibrator(
10191033
symmetric=symmetric,
10201034
moving_average=moving_average,
10211035
averaging_constant=averaging_constant,
1036+
data_types_to_calibrate=data_types_to_calibrate,
10221037
)
10231038
elif calibrate_method == CalibrationMethod.Entropy:
10241039
# default settings for entropy algorithm
@@ -1033,6 +1048,7 @@ def create_calibrator(
10331048
symmetric=symmetric,
10341049
num_bins=num_bins,
10351050
num_quantized_bins=num_quantized_bins,
1051+
data_types_to_calibrate=data_types_to_calibrate,
10361052
)
10371053
elif calibrate_method == CalibrationMethod.Percentile:
10381054
# default settings for percentile algorithm
@@ -1047,6 +1063,7 @@ def create_calibrator(
10471063
symmetric=symmetric,
10481064
num_bins=num_bins,
10491065
percentile=percentile,
1066+
data_types_to_calibrate=data_types_to_calibrate,
10501067
)
10511068

10521069
elif calibrate_method == CalibrationMethod.Distribution:
@@ -1061,6 +1078,7 @@ def create_calibrator(
10611078
use_external_data_format=use_external_data_format,
10621079
num_bins=num_bins,
10631080
scenario=scenario,
1081+
data_types_to_calibrate=data_types_to_calibrate,
10641082
)
10651083

10661084
if calibrator:

0 commit comments

Comments (0)