Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
b641d11
Added enable_gpnpu option in session_options and included pybindings
maggiesquadric Jan 14, 2025
a3ec03d
Can now access enable_gpnpu flag in qlinearconv.cc
maggiesquadric Jan 14, 2025
7770c7e
Added functions for calculating frac bits
maggiesquadric Jan 16, 2025
a238a71
Debugging
maggiesquadric Jan 16, 2025
21d41d4
Working version
maggiesquadric Jan 17, 2025
f738d8b
Deleted debug prints in quantize.cpp
maggiesquadric Jan 17, 2025
03e3ea5
Added unit test for qlinearconv
maggiesquadric Jan 17, 2025
50d5afd
Cleaning up
maggiesquadric Jan 17, 2025
6f853f5
Edited test_qlinearconv.py
maggiesquadric Jan 17, 2025
70e4d38
Made new folder for gpnpumode tests
maggiesquadric Jan 17, 2025
2d40708
Qlinearadd fixed point version, working on unit test
maggiesquadric Jan 22, 2025
4ea5fad
Added unit test for QLinearAdd
maggiesquadric Jan 22, 2025
01bfdd9
New branch for qgemm
maggiesquadric Jan 23, 2025
d05bbd7
Working on create onnx for qgemm
maggiesquadric Jan 23, 2025
9a96fd9
Working on adding null inputs to onnx graph in test_qgemm.py
maggiesquadric Jan 23, 2025
17d7892
Qgemm working with unit test
maggiesquadric Jan 24, 2025
1f76d7a
Modified test_qgemm.py
maggiesquadric Jan 24, 2025
22d50a6
Added test_qlineargap.py
maggiesquadric Jan 24, 2025
d746185
Adding fixed point functionality
maggiesquadric Jan 24, 2025
f791a22
Added functionality for QLinearGlobalAveragePool
maggiesquadric Jan 27, 2025
97e8da6
Added unit test for QLinearGAP
maggiesquadric Jan 27, 2025
ab9c98a
Edited profiling for qlineargap
maggiesquadric Jan 28, 2025
4ae29ff
Edited profiling for qlineargap
maggiesquadric Jan 28, 2025
75e930e
Edits
maggiesquadric Jan 28, 2025
b8df013
Editing
maggiesquadric Jan 28, 2025
bcea5b8
Edits
maggiesquadric Jan 28, 2025
664c89b
Edited test_resnet50.py
maggiesquadric Jan 29, 2025
98d2377
profiling
maggiesquadric Jan 29, 2025
7d2e9f7
profiling
maggiesquadric Jan 29, 2025
542db11
Edits
maggiesquadric Jan 31, 2025
0f3faa3
Validation
maggiesquadric Jan 31, 2025
0fea2d8
m1?
maggiesquadric Jan 31, 2025
b5bcd25
lut op
maggiesquadric Feb 14, 2025
fefc674
lut op working
maggiesquadric Feb 18, 2025
0a8a0c5
lut mostly working
maggiesquadric Feb 20, 2025
a536c32
Changed lut table to attribute instead of input
maggiesquadric Feb 20, 2025
b66e8b0
Edited for mac to pass CI
maggiesquadric Feb 21, 2025
c28cb28
new way of registering lut
maggiesquadric Feb 21, 2025
ba6aad3
Cleaning up comments
maggiesquadric Mar 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLUT);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLUT);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid);
Expand Down
129 changes: 90 additions & 39 deletions onnxruntime/contrib_ops/cpu/quantization/qlinear_binary_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/providers/common.h"
#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"
#include "core/framework/op_kernel_context_internal.h"

using onnxruntime::concurrency::ThreadPool;

Expand Down Expand Up @@ -95,45 +96,95 @@ void QLinearImpl(OpKernelContext& context, double unit_cost, const ProcessBroadc

template <typename T>
Status QLinearAdd<T>::Compute(OpKernelContext* context) const {
const ProcessBroadcastSpanFuncs functors = {
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
const T input0 = per_iter_bh.ScalarInput0<T>();
auto input1 = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAdd(input1.data(),
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
&input0,
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), true);
},
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
auto input0 = per_iter_bh.SpanInput0<T>();
const T input1 = per_iter_bh.ScalarInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();
MlasQLinearAdd(input0.data(),
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
&input1,
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), true);
},
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
auto input0 = per_iter_bh.SpanInput0<T>();
auto input1 = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAdd(input0.data(),
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
input1.data(),
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), false);
}};
auto* internal_context = dynamic_cast<OpKernelContextInternal*>(context);
if (!internal_context) {
return Status(common::ONNXRUNTIME, common::FAIL, "Failed to cast OpKernelContext to OpKernelContextInternal");
}
const auto& session_options = internal_context->GetSessionState().GetSessionOptions();
// Test to see if we have access to enable_gpnpu flag
const bool gpnpu_flag = session_options.enable_gpnpu;

const ProcessBroadcastSpanFuncs functors = gpnpu_flag ? ProcessBroadcastSpanFuncs{
// Reviewer note (author): ternary on gpnpu_flag — when enable_gpnpu is set,
// the functors call MlasQLinearAddFixedPoint; the else branch keeps the
// original MlasQLinearAdd implementation.

[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
const T input0 = per_iter_bh.ScalarInput0<T>();
auto input1 = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAddFixedPoint(input1.data(),
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
&input0,
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), true);
},
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
auto input0 = per_iter_bh.SpanInput0<T>();
const T input1 = per_iter_bh.ScalarInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAddFixedPoint(input0.data(),
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
&input1,
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), true);
},
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
auto input0 = per_iter_bh.SpanInput0<T>();
auto input1 = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAddFixedPoint(input0.data(),
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
input1.data(),
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), false);
}
} : ProcessBroadcastSpanFuncs{
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
const T input0 = per_iter_bh.ScalarInput0<T>();
auto input1 = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAdd(input1.data(),
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
&input0,
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), true);
},
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
auto input0 = per_iter_bh.SpanInput0<T>();
const T input1 = per_iter_bh.ScalarInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAdd(input0.data(),
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
&input1,
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), true);
},
[](BroadcastHelper& per_iter_bh) {
QLinearBroadcastHelper& qlbh = static_cast<QLinearBroadcastHelper&>(per_iter_bh);
auto input0 = per_iter_bh.SpanInput0<T>();
auto input1 = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

MlasQLinearAdd(input0.data(),
qlbh.A_scale, static_cast<T>(qlbh.A_zero_point),
input1.data(),
qlbh.B_scale, static_cast<T>(qlbh.B_zero_point),
qlbh.C_scale, static_cast<T>(qlbh.C_zero_point),
output.data(), output.size(), false);
}
};

QLinearImpl<T>(*context, 1.0, functors);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "core/util/math.h"
#include "core/mlas/inc/mlas.h"
#include <functional>
#include "core/framework/op_kernel_context_internal.h"

using onnxruntime::concurrency::ThreadPool;

Expand Down Expand Up @@ -55,6 +56,46 @@ Status ComputeQLinearGlobalAvgPool(
return Status::OK();
}

// Fixed-point variant of ComputeQLinearGlobalAvgPool, selected at runtime when
// the enable_gpnpu session option is set. It mirrors the float-scale version
// but dispatches to MlasQLinearGlobalAveragePoolNchwFixedPoint /
// MlasQLinearGlobalAveragePoolNhwcFixedPoint.
// NOTE(review): intentionally duplicated rather than refactored because the
// gpnpu flag is only reachable from session options at this level, not inside
// the MLAS kernels; could be merged if the flag is ever plumbed down.
//
// x / y: quantized input/output buffers (T8Bits is int8_t or uint8_t).
// N, C, image_size: batch count, channel count, spatial element count.
// channels_last: true for NHWC layout, false for NCHW.
// Returns Status::OK() unconditionally; work is parallelized on tp.
template <typename T8Bits>
Status ComputeQLinearGlobalAvgPoolFixedPoint(
    const T8Bits* x,
    float x_scale,
    T8Bits x_zero_point,
    T8Bits* y,
    float y_scale,
    T8Bits y_zero_point,
    int64_t N,
    int64_t C,
    int64_t image_size,
    bool channels_last,
    concurrency::ThreadPool* tp) {
  if (!channels_last || C == 1) {
    // NCHW (or single-channel NHWC): each of the N*C rows reduces image_size
    // elements to a single output value.
    auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) {
      // Pointer arithmetic already yields the correct types; the original
      // C-style casts here were no-ops and have been removed.
      const T8Bits* input = x + first * image_size;
      T8Bits* output = y + first;
      std::vector<int32_t> acc_buffer(MlasQLinearSafePaddingElementCount(sizeof(int32_t), last - first));
      MlasQLinearGlobalAveragePoolNchwFixedPoint(input, x_scale, x_zero_point, output, y_scale, y_zero_point,
                                                 last - first, narrow<size_t>(image_size), acc_buffer.data());
    };
    concurrency::ThreadPool::TryParallelFor(
        tp, static_cast<std::ptrdiff_t>(N * C), {1.0 * image_size, 1.0, 8.0 * image_size}, worker);
  } else {
    // NHWC: parallelize over batches; each worker reduces image_size x C.
    auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) {
      const T8Bits* input = x + first * C * image_size;
      T8Bits* output = y + first * C;
      std::vector<int32_t> acc_buffer(MlasQLinearSafePaddingElementCount(sizeof(int32_t), narrow<size_t>(C)));
      std::vector<T8Bits> zero_buffer(MlasQLinearSafePaddingElementCount(sizeof(T8Bits), narrow<size_t>(C)), 0);
      MlasQLinearGlobalAveragePoolNhwcFixedPoint(
          input, x_scale, x_zero_point, output, y_scale, y_zero_point,
          last - first, narrow<size_t>(image_size), narrow<size_t>(C), narrow<size_t>(C),
          acc_buffer.data(), zero_buffer.data());
    };
    concurrency::ThreadPool::TryParallelFor(
        tp, static_cast<std::ptrdiff_t>(N),
        {1.0 * image_size * C, 1.0 * C, 8.0 * image_size * C},
        worker);
  }
  return Status::OK();
}

// GCC's unexplained behavior:
// GCC wouldn't generate corresponding symbols versus function instances below when "--disable-exceptions"
// and "--minimal-build" are combined on linux build.
Expand Down Expand Up @@ -87,6 +128,32 @@ template Status ComputeQLinearGlobalAvgPool<uint8_t>(
bool channels_last,
concurrency::ThreadPool* tp);

// Explicit instantiations for the two supported 8-bit quantized element types,
// mirroring the existing ComputeQLinearGlobalAvgPool instantiations above
// (see the GCC note: these are required under --disable-exceptions /
// --minimal-build combinations).
template Status ComputeQLinearGlobalAvgPoolFixedPoint<int8_t>(
    const int8_t* x,
    float x_scale,
    int8_t x_zero_point,
    int8_t* y,
    float y_scale,
    int8_t y_zero_point,
    int64_t N,
    int64_t C,
    int64_t image_size,
    bool channels_last,
    concurrency::ThreadPool* tp);

template Status ComputeQLinearGlobalAvgPoolFixedPoint<uint8_t>(
    const uint8_t* x,
    float x_scale,
    uint8_t x_zero_point,
    uint8_t* y,
    float y_scale,
    uint8_t y_zero_point,
    int64_t N,
    int64_t C,
    int64_t image_size,
    bool channels_last,
    concurrency::ThreadPool* tp);

Status QLinearGlobalAveragePool::Compute(OpKernelContext* context) const {
const auto tensor_x_scale = context->Input<Tensor>(1);
const auto tensor_x_zero_point = context->Input<Tensor>(2);
Expand Down Expand Up @@ -124,14 +191,35 @@ Status QLinearGlobalAveragePool::Compute(OpKernelContext* context) const {
const float y_scale = *(tensor_y_scale->Data<float>());

auto dtype = X.GetElementType();
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
return ComputeQLinearGlobalAvgPool(X.Data<uint8_t>(), x_scale, *(tensor_x_zero_point->Data<uint8_t>()),
Y.MutableData<uint8_t>(), y_scale, *(tensor_y_zero_point->Data<uint8_t>()),
N, C, image_size, channels_last_, tp);

// Reviewer note (author): if gpnpu is enabled, go to the fixed-point version;
// otherwise run the original code path.

auto* internal_context = dynamic_cast<OpKernelContextInternal*>(context);
if (!internal_context) {
return Status(common::ONNXRUNTIME, common::FAIL, "Failed to cast OpKernelContext to OpKernelContextInternal");
}
const auto& session_options = internal_context->GetSessionState().GetSessionOptions();
// Test to see if we have access to enable_gpnpu flag
const bool gpnpu_flag = session_options.enable_gpnpu;

if (gpnpu_flag) {
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
return ComputeQLinearGlobalAvgPoolFixedPoint(X.Data<uint8_t>(), x_scale, *(tensor_x_zero_point->Data<uint8_t>()),
Y.MutableData<uint8_t>(), y_scale, *(tensor_y_zero_point->Data<uint8_t>()),
N, C, image_size, channels_last_, tp);
} else {
return ComputeQLinearGlobalAvgPoolFixedPoint(X.Data<int8_t>(), x_scale, *(tensor_x_zero_point->Data<int8_t>()),
Y.MutableData<int8_t>(), y_scale, *(tensor_y_zero_point->Data<int8_t>()),
N, C, image_size, channels_last_, tp);
}
} else {
return ComputeQLinearGlobalAvgPool(X.Data<int8_t>(), x_scale, *(tensor_x_zero_point->Data<int8_t>()),
Y.MutableData<int8_t>(), y_scale, *(tensor_y_zero_point->Data<int8_t>()),
N, C, image_size, channels_last_, tp);
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
return ComputeQLinearGlobalAvgPool(X.Data<uint8_t>(), x_scale, *(tensor_x_zero_point->Data<uint8_t>()),
Y.MutableData<uint8_t>(), y_scale, *(tensor_y_zero_point->Data<uint8_t>()),
N, C, image_size, channels_last_, tp);
} else {
return ComputeQLinearGlobalAvgPool(X.Data<int8_t>(), x_scale, *(tensor_x_zero_point->Data<int8_t>()),
Y.MutableData<int8_t>(), y_scale, *(tensor_y_zero_point->Data<int8_t>()),
N, C, image_size, channels_last_, tp);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,19 @@ Status ComputeQLinearGlobalAvgPool(
bool channels_last,
concurrency::ThreadPool* tp);

// Fixed-point variant of ComputeQLinearGlobalAvgPool, used when the
// enable_gpnpu session option is set. T8Bits is the quantized element type
// (int8_t or uint8_t); x/y hold N batches of C channels with image_size
// spatial elements each, laid out NHWC when channels_last is true, NCHW
// otherwise. Work is parallelized on tp.
template <typename T8Bits>
Status ComputeQLinearGlobalAvgPoolFixedPoint(
const T8Bits* x,
float x_scale,
T8Bits x_zero_point,
T8Bits* y,
float y_scale,
T8Bits y_zero_point,
int64_t N,
int64_t C,
int64_t image_size,
bool channels_last,
concurrency::ThreadPool* tp);

} // namespace contrib
} // namespace onnxruntime
53 changes: 51 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "core/providers/cpu/quantization/matmul_integer_base.h"
#include "core/quantization/quantization.h"
#include "core/util/math_cpuonly.h"
#include "core/util/qmath.h"
#include "core/mlas/inc/mlas.h"
#include "core/framework/op_kernel_context_internal.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -18,6 +21,14 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
}

Status Compute(OpKernelContext* context) const override {
auto* internal_context = dynamic_cast<OpKernelContextInternal*>(context);
if (!internal_context) {
return Status(common::ONNXRUNTIME, common::FAIL, "Failed to cast OpKernelContext to OpKernelContextInternal");
}
const auto& session_options = internal_context->GetSessionState().GetSessionOptions();
// Test to see if we have access to enable_gpnpu flag
const bool gpnpu_flag = session_options.enable_gpnpu;

const auto* a = context->Input<Tensor>(IN_A);
const auto* b = packed_b_ ? nullptr : context->Input<Tensor>(IN_B);
const auto& b_shape = b ? b->Shape() : b_shape_;
Expand Down Expand Up @@ -106,9 +117,17 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
gemm_param.PerColumnZeroPoints = !IsScalarOr1ElementVector(b_zp);

std::vector<float> output_scales = ComputeOutputScale(a_scale, b_scale, y_scale);
std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR> scale_bias_proc_ptr;

// Reviewer note (author): two additional output processors are defined for the
// fixed-point path; as discussed with Chris, this could be refactored.

std::optional<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR_FIXEDPOINT> requant_proc_ptr_fixedpoint;
std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR_FIXEDPOINT> scale_bias_proc_ptr_fixedpoint;
std::optional<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR> requant_proc_ptr;
SetPostProcessor(y_zp, N, output_scales, y, gemm_param, scale_bias_proc_ptr, requant_proc_ptr);
std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR> scale_bias_proc_ptr;

if (gpnpu_flag) {
SetPostProcessorFixedPoint(y_zp, N, output_scales, y, gemm_param, scale_bias_proc_ptr_fixedpoint, requant_proc_ptr_fixedpoint);
} else {
SetPostProcessor(y_zp, N, output_scales, y, gemm_param, scale_bias_proc_ptr, requant_proc_ptr);
}

MlasGemmBatch(gemm_shape, &gemm_param, 1, context->GetOperatorThreadPool());
return Status::OK();
Expand Down Expand Up @@ -210,6 +229,36 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
gemm_param.OutputProcessor = &*scale_bias_proc_ptr;
}
}
static void SetPostProcessorFixedPoint(const Tensor* y_zp,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 new processors defined above get used here

size_t out_lda,
const std::vector<float>& output_scales,
Tensor* y,
MLAS_GEMM_QUANT_DATA_PARAMS& gemm_param,
std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR_FIXEDPOINT>& scale_bias_proc_ptr,
std::optional<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR_FIXEDPOINT>& requant_proc_ptr) {
if (nullptr != y_zp) {
bool is_y_signed = y->IsDataType<int8_t>();
int32_t y_zero_point = is_y_signed ? *y_zp->Data<int8_t>() : *y_zp->Data<uint8_t>();
requant_proc_ptr.emplace(
y->MutableDataRaw(),
out_lda,
nullptr,
output_scales.data(),
output_scales.size() > 1,
y_zero_point,
is_y_signed);
gemm_param.OutputProcessor = &*requant_proc_ptr;
} else {
scale_bias_proc_ptr.emplace(
static_cast<float*>(y->MutableDataRaw()),
out_lda,
output_scales.data(),
nullptr,
MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
output_scales.size() > 1 ? MLAS_QUANTIZATION_GRANULARITY::PerColumn : MLAS_QUANTIZATION_GRANULARITY::PerMatrix);
gemm_param.OutputProcessor = &*scale_bias_proc_ptr;
}
}
};

ONNX_OPERATOR_TYPED_KERNEL_EX(
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/framework/op_kernel_context_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class OpKernelContextInternal : public OpKernelContext {
return session_state_.GetUseDeterministicCompute();
}

// Returns the SessionState this kernel context was created with. Exposed so
// kernels can reach session-level configuration (e.g. the enable_gpnpu
// session option read by the quantized ops) at Compute() time.
const SessionState& GetSessionState() const {
return session_state_;
}

const SessionState* SubgraphSessionState(const std::string& attribute_name) {
return session_state_.GetSubgraphSessionState(GetNodeIndex(), attribute_name);
}
Expand Down
Loading