
Commit 44d9f6c

adding dequantizeLinearFixedPoint and quantizeLinearFixedPoint (#40)

* adding demo script; adding to track progress, but will remove commit when ready to PR
* all fixedpoint version
* working fixedpoint version in python
* saving
* have python prototype working for dequantizeLinear
* renamed file
* working tvm script for quantizeLinear
* saving working python prototype of quantizeLinear op
* first version of core implementation of dequantizeLinear
* working version of dequantizeLinear contrib op
* updated dequantizeLinearFixedPoint test
* first working version of quantizeLinearFixedPoint contrib op
* removing onnxruntime/test/python/gpnpumode/test_lutop.py
* removing all python files for manual testing
* changed roundToPosInf so that template arg becomes function arg
* using helper functions used by qlinearconv
* removed debugging print statements and unnecessary comments
* added test
* cleaned up code
* consistent casing
* remove debug statements, fixed typo, removed unnecessary comments
* adding python test for quantizeLinearFixedPoint
* adding python test for dequantizeLinearFixedPoint
* added include <vector>
* corrected test that had overflow, since the input data provided was out of bounds for the choice of frac bits
* modified python test so that it compares with the float version of quantizeLinear
* modified python test for dequantizeLinear so it compares with the float version
* just ran formatting on modified sections of code
* updated wheel.yaml to install cmake<4, since the recently released cmake v4 is not yet supported by ORT
* update computeFracBits to match tvm algo
1 parent: a457271

10 files changed: +693 −85 lines


.github/workflows/wheel.yaml

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ jobs:
       - name: Build ONNX Runtime wheel
         working-directory: /workspace
         run: |
-          python3 -m pip install cmake --upgrade
+          python3 -m pip install "cmake<4"
           ./build.sh --build_wheel --config Release --parallel ${{ github.event_name == 'pull_request' && ' ' || '--skip_tests'}} --skip_submodule_sync --allow_running_as_root --compile_no_warning_as_error
           wheel_path=$(find . -name '*.whl' | xargs readlink -f)
           echo "wheel_path=$wheel_path" >> $GITHUB_ENV

onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

Lines changed: 7 additions & 0 deletions

@@ -153,6 +153,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
 
+// Quadric contrib ops
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, DequantizeLinearFixedPoint);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, QuantizeLinearFixedPoint);
+
 #ifdef ENABLE_ATEN
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen);
 #endif
@@ -366,6 +370,9 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,
+      // Quadric contrib ops
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, DequantizeLinearFixedPoint)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, QuantizeLinearFixedPoint)>,
 
 #ifdef ENABLE_ATEN
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
Lines changed: 168 additions & 0 deletions (new file)

@@ -0,0 +1,168 @@
+#include "core/framework/op_kernel.h"
+#include "core/common/common.h"
+#include <cmath>    // For log2()
+#include <limits>   // For int8_t min/max
+#include <iostream>
+#include <iomanip>  // For std::setprecision
+#include "core/mlas/inc/mlas.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+// --- DequantizeLinearFixedPoint
+
+class DequantizeLinearFixedPoint final : public OpKernel {
+ public:
+  explicit DequantizeLinearFixedPoint(const OpKernelInfo& info) : OpKernel(info) {}
+  Status Compute(OpKernelContext* ctx) const override;
+};
+
+// Register kernel
+ONNX_OPERATOR_KERNEL_EX(
+    DequantizeLinearFixedPoint,
+    kQuadricDomain,  // Ensure this is defined in contrib_ops.h
+    1,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<int8_t>())     // Input tensor
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())     // Scale
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())    // Zero-point
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),  // Output
+    DequantizeLinearFixedPoint);
+
+// Compute min/max range from scale & zero-point
+std::pair<float, float> getDequantizedRange(float scale, int8_t zeroPoint) {
+  constexpr int8_t int8Min = std::numeric_limits<int8_t>::min();
+  constexpr int8_t int8Max = std::numeric_limits<int8_t>::max();
+  return {(int8Min - zeroPoint) * scale, (int8Max - zeroPoint) * scale};
+}
+
+// Compute required fractional bits given a range
+int computeFracBits(float minVal, float maxVal) {
+  constexpr int maxFracBits = 31;
+  float absMinVal = std::fabs(minVal);
+  float absMaxVal = std::fabs(maxVal);
+  if (absMinVal > absMaxVal) {
+    return (absMinVal < 1.0f) ? maxFracBits : (maxFracBits - static_cast<int>(std::ceil(std::log2(absMinVal))));
+  } else {
+    return (absMaxVal < 1.0f) ? maxFracBits : (maxFracBits - static_cast<int>(std::ceil(std::log2(absMaxVal + 1))));
+  }
+}
+
+// Fixed-point multiplication with provided shift
+int32_t fixedPointMultiply(int32_t a, int32_t b, int shift) {
+  int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);
+  return (shift > 0) ? (product >> shift) : (product << -shift);
+}
+
+Status DequantizeLinearFixedPoint::Compute(OpKernelContext* ctx) const {
+  // Retrieve input tensors
+  const auto* X = ctx->Input<Tensor>(0);
+  const auto* scale = ctx->Input<Tensor>(1);
+  const auto* zeroPoint = ctx->Input<Tensor>(2);
+
+  // Validate inputs
+  ORT_ENFORCE(X, "Input tensor 'X' is null.");
+  ORT_ENFORCE(scale, "Scale tensor is null.");
+  ORT_ENFORCE(zeroPoint, "Zero-point tensor is null.");
+
+  // Extract values
+  const int8_t* xData = X->Data<int8_t>();
+  float s = *(scale->Data<float>());
+  int8_t zp = *(zeroPoint->Data<int8_t>());
+
+  // Compute range and fractional bits
+  auto [minVal, maxVal] = getDequantizedRange(s, zp);
+  int resultFracBits = computeFracBits(minVal, maxVal);
+
+  // Convert scale to fixed-point
+  std::vector<double> scaleValueVec = {s};
+  auto p = dataToQfp(scaleValueVec, -1, 32, false);
+  int scaleFracBits = p.second;
+  int32_t scaleQfp = static_cast<int32_t>(p.first[0]);
+
+  int shift = scaleFracBits - resultFracBits;
+
+  // Allocate output tensor
+  auto* Y = ctx->Output(0, X->Shape());
+  int32_t* yData = Y->MutableData<int32_t>();
+  size_t tensorSize = X->Shape().Size();
+
+  for (size_t i = 0; i < tensorSize; ++i) {
+    yData[i] = fixedPointMultiply(xData[i] - zp, scaleQfp, shift);
+  }
+
+  return Status::OK();
+}
+
+// --- QuantizeLinearFixedPoint
+class QuantizeLinearFixedPoint final : public OpKernel {
+ public:
+  explicit QuantizeLinearFixedPoint(const OpKernelInfo& info) : OpKernel(info) {}
+  Status Compute(OpKernelContext* ctx) const override;
+};
+
+// Register Kernel
+ONNX_OPERATOR_KERNEL_EX(
+    QuantizeLinearFixedPoint,
+    kQuadricDomain,
+    1,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>())   // Input tensor
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())   // xFracBits
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>())    // Scale
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int8_t>())   // Zero-point
+        .TypeConstraint("T4", DataTypeImpl::GetTensorType<int8_t>()),  // Output
+    QuantizeLinearFixedPoint);
+
+Status QuantizeLinearFixedPoint::Compute(OpKernelContext* ctx) const {
+  // Get input tensors
+  const auto* X = ctx->Input<Tensor>(0);
+  const auto* xFracBitsTensor = ctx->Input<Tensor>(1);
+  const auto* scale = ctx->Input<Tensor>(2);
+  const auto* zeroPoint = ctx->Input<Tensor>(3);
+
+  // Validate inputs
+  ORT_ENFORCE(X != nullptr, "Input X is null");
+  ORT_ENFORCE(xFracBitsTensor != nullptr, "xFracBits is null");
+  ORT_ENFORCE(scale != nullptr, "Scale is null");
+  ORT_ENFORCE(zeroPoint != nullptr, "Zero point is null");
+
+  // Retrieve input data
+  const int32_t* x_data = X->Data<int32_t>();
+  int8_t xFracBits = *(xFracBitsTensor->Data<int8_t>());
+  double s = *(scale->Data<float>());
+  int8_t zp = *(zeroPoint->Data<int8_t>());
+
+  double scaleInv = 1.0 / s;
+  std::vector<double> ScaleValueVec = {scaleInv};
+  auto p = dataToQfp(ScaleValueVec, -1, 32, false);  // Returns std::make_pair(qfp, fracBits)
+  int64_t scaleInvQfp = p.first[0];
+  int scaleInvFracBits = p.second;
+
+  constexpr int postMacIntBits = 29;
+  constexpr int postMacFracBits = 31 - postMacIntBits;
+
+  int resultFracBits = postMacFracBits;
+  int shift = scaleInvFracBits + xFracBits - resultFracBits;
+  if (shift > 31) {
+    shift = 31;
+    resultFracBits = scaleInvFracBits + xFracBits - 31;
+  }
+
+  auto* Y = ctx->Output(0, X->Shape());
+  int8_t* yData = Y->MutableData<int8_t>();
+  size_t tensor_size = X->Shape().Size();
+  for (size_t i = 0; i < tensor_size; ++i) {
+    int32_t product = fixedPointMultiply(x_data[i], scaleInvQfp, shift);
+    int32_t productRound = fxRoundPosInf(static_cast<int32_t>(product), static_cast<uint8_t>(resultFracBits));
+
+    // Clip and apply zero-point
+    yData[i] = static_cast<int8_t>(std::min(std::max(productRound + zp, static_cast<int32_t>(std::numeric_limits<int8_t>::min())), static_cast<int32_t>(std::numeric_limits<int8_t>::max())));
+  }
+  return Status::OK();
+}
+
+}  // namespace contrib
+}  // namespace onnxruntime
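To make the dequantize arithmetic concrete, here is a minimal standalone sketch of the same computation outside the kernel. getDequantizedRange, computeFracBits, and fixedPointMultiply are lifted from the kernel above (with one explicit narrowing cast added); scalarToQfp is a hypothetical stand-in for the MLAS dataToQfp helper, whose definition is not part of this diff, assuming it picks the largest fractional-bit count that keeps the value inside int32.

// Standalone sketch of DequantizeLinearFixedPoint's arithmetic (C++17).
// scalarToQfp is a HYPOTHETICAL stand-in for dataToQfp.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <utility>

std::pair<float, float> getDequantizedRange(float scale, int8_t zeroPoint) {
  constexpr int8_t int8Min = std::numeric_limits<int8_t>::min();
  constexpr int8_t int8Max = std::numeric_limits<int8_t>::max();
  return {(int8Min - zeroPoint) * scale, (int8Max - zeroPoint) * scale};
}

int computeFracBits(float minVal, float maxVal) {
  constexpr int maxFracBits = 31;
  float absMinVal = std::fabs(minVal);
  float absMaxVal = std::fabs(maxVal);
  if (absMinVal > absMaxVal) {
    return (absMinVal < 1.0f) ? maxFracBits : (maxFracBits - static_cast<int>(std::ceil(std::log2(absMinVal))));
  } else {
    return (absMaxVal < 1.0f) ? maxFracBits : (maxFracBits - static_cast<int>(std::ceil(std::log2(absMaxVal + 1))));
  }
}

int32_t fixedPointMultiply(int32_t a, int32_t b, int shift) {
  int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  return static_cast<int32_t>((shift > 0) ? (product >> shift) : (product << -shift));
}

// HYPOTHETICAL stand-in for dataToQfp on one scalar: the largest fracBits
// such that |v| * 2^fracBits still fits in int32.
std::pair<int32_t, int> scalarToQfp(double v) {
  int fracBits = 31 - static_cast<int>(std::ceil(std::log2(std::fabs(v))));
  double scaled = std::ldexp(v, fracBits);
  if (std::fabs(scaled) > 2147483647.0) {
    --fracBits;
    scaled = std::ldexp(v, fracBits);
  }
  return {static_cast<int32_t>(std::llround(scaled)), fracBits};
}

int main() {
  const float s = 0.05f;  // scale
  const int8_t zp = 3;    // zero-point
  const int8_t x = 100;   // one quantized element

  auto [minVal, maxVal] = getDequantizedRange(s, zp);    // [-6.55, 6.2]
  int resultFracBits = computeFracBits(minVal, maxVal);  // 31 - ceil(log2(6.55)) = 28

  auto [scaleQfp, scaleFracBits] = scalarToQfp(s);  // ~0.05 * 2^35, fracBits = 35
  int shift = scaleFracBits - resultFracBits;       // 35 - 28 = 7

  int32_t y = fixedPointMultiply(x - zp, scaleQfp, shift);
  std::cout << "fixed-point: " << y / std::ldexp(1.0, resultFracBits)
            << "  float reference: " << (x - zp) * s << "\n";
  return 0;
}

With scale = 0.05, zero_point = 3, and x = 100, both paths print roughly 4.85, and the fixed-point result carries 28 fractional bits.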

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 97 additions & 0 deletions

@@ -3697,6 +3697,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
          "Allow inputs and outputs to be any kind of tensor.");
 #endif
 
+  // Quadric contrib ops
   ONNX_CONTRIB_OPERATOR_SCHEMA(QuadricCustomOp)
       .SetDomain(kQuadricDomain)
       .SinceVersion(1)
@@ -3715,6 +3716,102 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
       .TypeConstraint("T", OpSchema::all_tensor_types_ir4(),
                       "Allow inputs and outputs to be any kind of tensor.");
 
+  // Quadric ops
+  ONNX_CONTRIB_OPERATOR_SCHEMA(DequantizeLinearFixedPoint)
+      .SetDomain(kQuadricDomain)
+      .SinceVersion(1)
+      .SetDoc(R"DOC(
+Dequantizes an int8 input tensor into a fixed-point int32 output tensor using integer arithmetic.
+The dequantization formula is:
+
+    Y_fixed = ((X - zero_point) * scale_qfp) >> shift
+
+where `scale_qfp` is the scale converted into fixed-point representation.
+
+- `X` is the quantized input tensor (int8).
+- `scale` is a floating-point scalar that is converted into a fixed-point multiplier `scale_qfp`.
+- `zero_point` is the quantization zero-point (int8), which is subtracted from `X` before scaling.
+- `Y_fixed` is the output tensor (int32) interpreted as a fixed-point representation.
+
+Unlike `DequantizeLinear`, which produces floating-point outputs, this operator retains
+a fixed-point integer format to align with Quadric's CGC execution.
+
+This operator does **per-tensor dequantization**, meaning `scale` and `zero_point` are scalars.
+)DOC")
+
+      // Inputs
+      .Input(0, "X", "N-D quantized input tensor (int8).", "T")
+      .Input(1, "scale", "Scalar scale factor (float). Converted to fixed-point format internally.", "T1")
+      .Input(2, "zero_point", "Scalar zero-point offset (int8). Must match type of X.", "T2")
+
+      // Outputs
+      .Output(0, "Y", "N-D output tensor (int32). Fixed-point representation.", "T3")
+
+      // Type Constraints
+      .TypeConstraint("T", {"tensor(int8)"}, "Input tensor must be int8.")
+      .TypeConstraint("T1", {"tensor(float)"}, "Scale must be a floating-point scalar.")
+      .TypeConstraint("T2", {"tensor(int8)"}, "Zero point must be int8, matching the input tensor type.")
+      .TypeConstraint("T3", {"tensor(int32)"}, "Output tensor is int32 (fixed-point representation).")
+
+      // Shape Inference
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        auto y_type = ctx.getOutputType(0);
+        y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32);
+
+        if (!hasInputShape(ctx, 0))
+          return;
+
+        auto& input_shape = getInputShape(ctx, 0);
+        updateOutputShape(ctx, 0, input_shape);
+      });
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinearFixedPoint)
+      .SetDomain(kQuadricDomain)
+      .SinceVersion(1)
+      .SetDoc(R"DOC(
+Quantizes an int32 input tensor into an int8 output tensor using fixed-point arithmetic.
+
+The quantization formula is:
+
+    Y_q = saturate(round((X * scale_inv_qfp) >> shift) + zero_point)
+
+where:
+- `X` is the input tensor in int32 (fixed-point representation).
+- `scale_inv_qfp` is the inverse of scale in fixed-point format.
+- `zero_point` is the quantization zero-point.
+- `Y_q` is the quantized output in int8, saturated to the int8 range.
+
+This operator does **per-tensor quantization**, meaning `scale` and `zero_point` are scalars.
+)DOC")
+
+      // Inputs
+      .Input(0, "X", "N-D input tensor (int32, fixed-point).", "T")
+      .Input(1, "x_frac_bits", "Fractional bits of input (int8).", "T1")
+      .Input(2, "scale", "Scalar scale factor (float).", "T2")
+      .Input(3, "zero_point", "Scalar zero-point offset (int8).", "T3")
+
+      // Outputs
+      .Output(0, "Y", "N-D output tensor (int8).", "T4")
+
+      // Type Constraints
+      .TypeConstraint("T", {"tensor(int32)"}, "Input tensor must be int32.")
+      .TypeConstraint("T1", {"tensor(int8)"}, "Fractional bits must be int8.")
+      .TypeConstraint("T2", {"tensor(float)"}, "Scale must be a floating-point scalar.")
+      .TypeConstraint("T3", {"tensor(int8)"}, "Zero point must be int8, matching the output tensor type.")
+      .TypeConstraint("T4", {"tensor(int8)"}, "Output tensor is int8.")
+
+      // Shape Inference
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        auto y_type = ctx.getOutputType(0);
+        y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT8);
+
+        if (!hasInputShape(ctx, 0))
+          return;
+
+        auto& input_shape = getInputShape(ctx, 0);
+        updateOutputShape(ctx, 0, input_shape);
+      });
+
 #ifdef ENABLE_TRAINING_OPS
 // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
 // 2). this is needed by inference for other purpose.
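The quantize direction, as a numeric walk-through continuing the numbers from the dequantize sketch above. The Q-format constants for 1/scale and the body of fxRoundPosInfSketch are assumptions (neither dataToQfp nor fxRoundPosInf is defined in this diff); the point is how the shift is capped at 31 with the residue folded into resultFracBits, and how rounding happens before the zero-point is added.

// Numeric walk-through of QuantizeLinearFixedPoint for one element.
#include <algorithm>
#include <cstdint>
#include <iostream>

int32_t fixedPointMultiply(int32_t a, int32_t b, int shift) {
  int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  return static_cast<int32_t>((shift > 0) ? (product >> shift) : (product << -shift));
}

// ASSUMED behavior of fxRoundPosInf: add half an LSB, then shift.
int32_t fxRoundPosInfSketch(int32_t a, uint8_t aFracBits) {
  return (a + (int32_t(1) << (aFracBits - 1))) >> aFracBits;
}

int main() {
  // Input: 4.85 in Q28 (the dequantize example's output), scale = 0.05, zp = 3.
  const int32_t x = 1301911961;            // ~4.85 * 2^28
  const int xFracBits = 28;
  const int32_t scaleInvQfp = 1342177280;  // assumed dataToQfp output: 20.0 * 2^26
  const int scaleInvFracBits = 26;
  const int8_t zp = 3;

  // postMacFracBits = 31 - 29 = 2, so the ideal shift is 26 + 28 - 2 = 52;
  // it is capped at 31 and the residue moves into resultFracBits = 23.
  int resultFracBits = 2;
  int shift = scaleInvFracBits + xFracBits - resultFracBits;  // 52
  if (shift > 31) {
    shift = 31;
    resultFracBits = scaleInvFracBits + xFracBits - 31;       // 23
  }

  int32_t product = fixedPointMultiply(x, scaleInvQfp, shift);                // ~97 in Q23
  int32_t rounded = fxRoundPosInfSketch(product, uint8_t(resultFracBits));    // 97
  int32_t clipped = std::min(std::max(rounded + zp, -128), 127);              // clamp to int8
  std::cout << clipped << "\n";  // 100: the round trip recovers the original x
  return 0;
}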

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 26 additions & 17 deletions

@@ -18,9 +18,10 @@ Module Name:
 #pragma once
 
 #include <cstddef>
-#include <cstdlib>
 #include <cstdint>
+#include <cstdlib>
 #include <stdexcept>
+#include <vector>
 
 //
 // Define the calling convention for Windows targets.
@@ -1268,24 +1269,33 @@ MlasRequantizeOutput(
     size_t CountN
     );
 
-template<typename OutputType>
+template <typename OutputType>
 void
-MLASCALL
-MlasRequantizeOutputFixedPoint(
-    const int32_t* Input,
-    size_t InputLeadingDimension,
-    OutputType* Output,
-    size_t OutputLeadingDimension,
-    const int32_t* Bias,
-    const float* Scale,
-    bool PerColumnScale,
-    OutputType ZeroPoint,
-    size_t StartM,
-    size_t StartN,
-    size_t CountM,
-    size_t CountN
+MLASCALL
+MlasRequantizeOutputFixedPoint(
+    const int32_t* Input,
+    size_t InputLeadingDimension,
+    OutputType* Output,
+    size_t OutputLeadingDimension,
+    const int32_t* Bias,
+    const float* Scale,
+    bool PerColumnScale,
+    OutputType ZeroPoint,
+    size_t StartM,
+    size_t StartN,
+    size_t CountM,
+    size_t CountN
     );
 
+int32_t
+fxRoundPosInf(const int32_t a, uint8_t aFracBits);
+
+template <typename T>
+std::pair<std::vector<int>, int>
+dataToQfp(
+    const std::vector<T>& data, int fracBits = -1, int qfpSize = 32, bool scalarAsFloat = true
+    );
+
 class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
 {
 public:
@@ -1336,7 +1346,6 @@ class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
     bool OutputIsSigned_;
 };
 
-
 void
 MLASCALL
 MlasFindMinMaxElement(
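fxRoundPosInf and dataToQfp above are declarations only; their definitions live elsewhere in MLAS. As a rough mental model, here is a hypothetical sketch of what dataToQfp plausibly computes, inferred from its signature and its call sites in the kernels earlier in this commit (the scalarAsFloat flag is ignored here, and the real implementation may differ):

// HYPOTHETICAL sketch of dataToQfp, inferred from signature and call sites.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

template <typename T>
std::pair<std::vector<int>, int>
dataToQfpSketch(const std::vector<T>& data, int fracBits = -1, int qfpSize = 32) {
  // The largest magnitude bounds how many fractional bits fit in qfpSize bits.
  double maxAbs = 0.0;
  for (const T& v : data) maxAbs = std::max(maxAbs, std::fabs(static_cast<double>(v)));
  if (maxAbs == 0.0) maxAbs = 1.0;  // avoid log2(0) for an all-zero vector

  if (fracBits < 0) {
    // Pick the largest fracBits with maxAbs * 2^fracBits < 2^(qfpSize-1).
    fracBits = (qfpSize - 1) - static_cast<int>(std::ceil(std::log2(maxAbs)));
    if (std::ldexp(maxAbs, fracBits) >= std::ldexp(1.0, qfpSize - 1)) --fracBits;
  }

  std::vector<int> qfp;
  qfp.reserve(data.size());
  for (const T& v : data)
    qfp.push_back(static_cast<int>(std::llround(std::ldexp(static_cast<double>(v), fracBits))));
  return {qfp, fracBits};
}

For the two scalars used by the kernels, this sketch reproduces the constants from the walk-throughs above: 0.05 maps to (1717986918, 35) and 20.0 maps to (1342177280, 26).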
