-
Notifications
You must be signed in to change notification settings - Fork 1
General LUT node #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
General LUT node #39
Changes from all commits
b641d11
a3ec03d
7770c7e
a238a71
21d41d4
f738d8b
03e3ea5
50d5afd
6f853f5
70e4d38
2d40708
4ea5fad
01bfdd9
d05bbd7
9a96fd9
17d7892
1f76d7a
22d50a6
d746185
f791a22
97e8da6
ab9c98a
4ae29ff
75e930e
b8df013
bcea5b8
664c89b
98d2377
7d2e9f7
542db11
0f3faa3
0fea2d8
b5bcd25
fefc674
0a8a0c5
a536c32
b66e8b0
c28cb28
ba6aad3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| #include "core/util/math.h" | ||
| #include "core/mlas/inc/mlas.h" | ||
| #include <functional> | ||
| #include "core/framework/op_kernel_context_internal.h" | ||
|
|
||
| using onnxruntime::concurrency::ThreadPool; | ||
|
|
||
|
|
@@ -55,6 +56,46 @@ Status ComputeQLinearGlobalAvgPool( | |
| return Status::OK(); | ||
| } | ||
|
|
||
| template <typename T8Bits> | ||
| Status ComputeQLinearGlobalAvgPoolFixedPoint( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is identical to ComputeQLinearGlobalAvgPool except it calls MlasQLinearGlobalAveragePoolNchwFixedPoint instead of MlasQLinearGlobalAveragePoolNchw and MlasQLinearGlobalAveragePoolNhwcFixedPoint instead of MlasQLinearGlobalAveragePoolNhwc. As discussed with Chris, this could definitely be refactored and deleted so MlasQLinearGlobalAveragePoolNchw and MlasQLinearGlobalAveragePoolNhwc have a way to determine the gpnpu flag inside. However, I believe I did not do this because the flag from session options can only be accessed from the highest level, not down in MlasQLinearGlobalAveragePoolNhwc and MlasQLinearGlobalAveragePoolNchw. |
||
| const T8Bits* x, | ||
| float x_scale, | ||
| T8Bits x_zero_point, | ||
| T8Bits* y, | ||
| float y_scale, | ||
| T8Bits y_zero_point, | ||
| int64_t N, | ||
| int64_t C, | ||
| int64_t image_size, | ||
| bool channels_last, | ||
| concurrency::ThreadPool* tp) { | ||
| if (!channels_last || C == 1) { | ||
| auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) { | ||
| const T8Bits* input = (const T8Bits*)(x + (first * image_size)); | ||
| T8Bits* output = (T8Bits*)(y + first); | ||
| std::vector<int32_t> acc_buffer(MlasQLinearSafePaddingElementCount(sizeof(int32_t), last - first)); | ||
| MlasQLinearGlobalAveragePoolNchwFixedPoint(input, x_scale, x_zero_point, output, y_scale, y_zero_point, last - first, narrow<size_t>(image_size), acc_buffer.data()); | ||
| }; | ||
| concurrency::ThreadPool::TryParallelFor( | ||
| tp, static_cast<std::ptrdiff_t>(N * C), {1.0 * image_size, 1.0, 8.0 * image_size}, worker); | ||
| } else { | ||
| auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) { | ||
| const T8Bits* input = x + first * C * image_size; | ||
| T8Bits* output = y + first * C; | ||
| std::vector<int32_t> acc_buffer(MlasQLinearSafePaddingElementCount(sizeof(int32_t), narrow<size_t>(C))); | ||
| std::vector<T8Bits> zero_buffer(MlasQLinearSafePaddingElementCount(sizeof(T8Bits), narrow<size_t>(C)), 0); | ||
| MlasQLinearGlobalAveragePoolNhwcFixedPoint( | ||
| input, x_scale, x_zero_point, output, y_scale, y_zero_point, | ||
| last - first, narrow<size_t>(image_size), narrow<size_t>(C), narrow<size_t>(C), acc_buffer.data(), zero_buffer.data()); | ||
| }; | ||
| concurrency::ThreadPool::TryParallelFor( | ||
| tp, static_cast<std::ptrdiff_t>(N), | ||
| {1.0 * image_size * C, 1.0 * C, 8.0 * image_size * C}, | ||
| worker); | ||
| } | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| // GCC's unexplained behavior: | ||
| // GCC wouldn't generate corresponding symbols versus function instances below when "--disable-exceptions" | ||
| // and "--minimal-build" are combined on linux build. | ||
|
|
@@ -87,6 +128,32 @@ template Status ComputeQLinearGlobalAvgPool<uint8_t>( | |
| bool channels_last, | ||
| concurrency::ThreadPool* tp); | ||
|
|
||
| template Status ComputeQLinearGlobalAvgPoolFixedPoint<int8_t>( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Template mirroring the existing structure from before. |
||
| const int8_t* x, | ||
| float x_scale, | ||
| int8_t x_zero_point, | ||
| int8_t* y, | ||
| float y_scale, | ||
| int8_t y_zero_point, | ||
| int64_t N, | ||
| int64_t C, | ||
| int64_t image_size, | ||
| bool channels_last, | ||
| concurrency::ThreadPool* tp); | ||
|
|
||
| template Status ComputeQLinearGlobalAvgPoolFixedPoint<uint8_t>( | ||
| const uint8_t* x, | ||
| float x_scale, | ||
| uint8_t x_zero_point, | ||
| uint8_t* y, | ||
| float y_scale, | ||
| uint8_t y_zero_point, | ||
| int64_t N, | ||
| int64_t C, | ||
| int64_t image_size, | ||
| bool channels_last, | ||
| concurrency::ThreadPool* tp); | ||
|
|
||
| Status QLinearGlobalAveragePool::Compute(OpKernelContext* context) const { | ||
| const auto tensor_x_scale = context->Input<Tensor>(1); | ||
| const auto tensor_x_zero_point = context->Input<Tensor>(2); | ||
|
|
@@ -124,14 +191,35 @@ Status QLinearGlobalAveragePool::Compute(OpKernelContext* context) const { | |
| const float y_scale = *(tensor_y_scale->Data<float>()); | ||
|
|
||
| auto dtype = X.GetElementType(); | ||
| if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { | ||
| return ComputeQLinearGlobalAvgPool(X.Data<uint8_t>(), x_scale, *(tensor_x_zero_point->Data<uint8_t>()), | ||
| Y.MutableData<uint8_t>(), y_scale, *(tensor_y_zero_point->Data<uint8_t>()), | ||
| N, C, image_size, channels_last_, tp); | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If gpnpu is enabled, go to the fixed-point version; otherwise, use the original code. |
||
| auto* internal_context = dynamic_cast<OpKernelContextInternal*>(context); | ||
| if (!internal_context) { | ||
| return Status(common::ONNXRUNTIME, common::FAIL, "Failed to cast OpKernelContext to OpKernelContextInternal"); | ||
| } | ||
| const auto& session_options = internal_context->GetSessionState().GetSessionOptions(); | ||
| // Test to see if we have access to enable_gpnpu flag | ||
| const bool gpnpu_flag = session_options.enable_gpnpu; | ||
|
|
||
| if (gpnpu_flag) { | ||
| if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { | ||
| return ComputeQLinearGlobalAvgPoolFixedPoint(X.Data<uint8_t>(), x_scale, *(tensor_x_zero_point->Data<uint8_t>()), | ||
| Y.MutableData<uint8_t>(), y_scale, *(tensor_y_zero_point->Data<uint8_t>()), | ||
| N, C, image_size, channels_last_, tp); | ||
| } else { | ||
| return ComputeQLinearGlobalAvgPoolFixedPoint(X.Data<int8_t>(), x_scale, *(tensor_x_zero_point->Data<int8_t>()), | ||
| Y.MutableData<int8_t>(), y_scale, *(tensor_y_zero_point->Data<int8_t>()), | ||
| N, C, image_size, channels_last_, tp); | ||
| } | ||
| } else { | ||
| return ComputeQLinearGlobalAvgPool(X.Data<int8_t>(), x_scale, *(tensor_x_zero_point->Data<int8_t>()), | ||
| Y.MutableData<int8_t>(), y_scale, *(tensor_y_zero_point->Data<int8_t>()), | ||
| N, C, image_size, channels_last_, tp); | ||
| if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { | ||
| return ComputeQLinearGlobalAvgPool(X.Data<uint8_t>(), x_scale, *(tensor_x_zero_point->Data<uint8_t>()), | ||
| Y.MutableData<uint8_t>(), y_scale, *(tensor_y_zero_point->Data<uint8_t>()), | ||
| N, C, image_size, channels_last_, tp); | ||
| } else { | ||
| return ComputeQLinearGlobalAvgPool(X.Data<int8_t>(), x_scale, *(tensor_x_zero_point->Data<int8_t>()), | ||
| Y.MutableData<int8_t>(), y_scale, *(tensor_y_zero_point->Data<int8_t>()), | ||
| N, C, image_size, channels_last_, tp); | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,9 @@ | |
| #include "core/providers/cpu/quantization/matmul_integer_base.h" | ||
| #include "core/quantization/quantization.h" | ||
| #include "core/util/math_cpuonly.h" | ||
| #include "core/util/qmath.h" | ||
| #include "core/mlas/inc/mlas.h" | ||
| #include "core/framework/op_kernel_context_internal.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
|
|
@@ -18,6 +21,14 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { | |
| } | ||
|
|
||
| Status Compute(OpKernelContext* context) const override { | ||
| auto* internal_context = dynamic_cast<OpKernelContextInternal*>(context); | ||
| if (!internal_context) { | ||
| return Status(common::ONNXRUNTIME, common::FAIL, "Failed to cast OpKernelContext to OpKernelContextInternal"); | ||
| } | ||
| const auto& session_options = internal_context->GetSessionState().GetSessionOptions(); | ||
| // Test to see if we have access to enable_gpnpu flag | ||
| const bool gpnpu_flag = session_options.enable_gpnpu; | ||
|
|
||
| const auto* a = context->Input<Tensor>(IN_A); | ||
| const auto* b = packed_b_ ? nullptr : context->Input<Tensor>(IN_B); | ||
| const auto& b_shape = b ? b->Shape() : b_shape_; | ||
|
|
@@ -106,9 +117,17 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { | |
| gemm_param.PerColumnZeroPoints = !IsScalarOr1ElementVector(b_zp); | ||
|
|
||
| std::vector<float> output_scales = ComputeOutputScale(a_scale, b_scale, y_scale); | ||
| std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR> scale_bias_proc_ptr; | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Define 2 additional processors for fixed point. As discussed with Chris, this could be refactored. |
||
| std::optional<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR_FIXEDPOINT> requant_proc_ptr_fixedpoint; | ||
| std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR_FIXEDPOINT> scale_bias_proc_ptr_fixedpoint; | ||
| std::optional<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR> requant_proc_ptr; | ||
| SetPostProcessor(y_zp, N, output_scales, y, gemm_param, scale_bias_proc_ptr, requant_proc_ptr); | ||
| std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR> scale_bias_proc_ptr; | ||
|
|
||
| if (gpnpu_flag) { | ||
| SetPostProcessorFixedPoint(y_zp, N, output_scales, y, gemm_param, scale_bias_proc_ptr_fixedpoint, requant_proc_ptr_fixedpoint); | ||
| } else { | ||
| SetPostProcessor(y_zp, N, output_scales, y, gemm_param, scale_bias_proc_ptr, requant_proc_ptr); | ||
| } | ||
|
|
||
| MlasGemmBatch(gemm_shape, &gemm_param, 1, context->GetOperatorThreadPool()); | ||
| return Status::OK(); | ||
|
|
@@ -210,6 +229,36 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { | |
| gemm_param.OutputProcessor = &*scale_bias_proc_ptr; | ||
| } | ||
| } | ||
| static void SetPostProcessorFixedPoint(const Tensor* y_zp, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The 2 new processors defined above get used here. |
||
| size_t out_lda, | ||
| const std::vector<float>& output_scales, | ||
| Tensor* y, | ||
| MLAS_GEMM_QUANT_DATA_PARAMS& gemm_param, | ||
| std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR_FIXEDPOINT>& scale_bias_proc_ptr, | ||
| std::optional<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR_FIXEDPOINT>& requant_proc_ptr) { | ||
| if (nullptr != y_zp) { | ||
| bool is_y_signed = y->IsDataType<int8_t>(); | ||
| int32_t y_zero_point = is_y_signed ? *y_zp->Data<int8_t>() : *y_zp->Data<uint8_t>(); | ||
| requant_proc_ptr.emplace( | ||
| y->MutableDataRaw(), | ||
| out_lda, | ||
| nullptr, | ||
| output_scales.data(), | ||
| output_scales.size() > 1, | ||
| y_zero_point, | ||
| is_y_signed); | ||
| gemm_param.OutputProcessor = &*requant_proc_ptr; | ||
| } else { | ||
| scale_bias_proc_ptr.emplace( | ||
| static_cast<float*>(y->MutableDataRaw()), | ||
| out_lda, | ||
| output_scales.data(), | ||
| nullptr, | ||
| MLAS_QGEMM_OUTPUT_MODE::ZeroMode, | ||
| output_scales.size() > 1 ? MLAS_QUANTIZATION_GRANULARITY::PerColumn : MLAS_QUANTIZATION_GRANULARITY::PerMatrix); | ||
| gemm_param.OutputProcessor = &*scale_bias_proc_ptr; | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| ONNX_OPERATOR_TYPED_KERNEL_EX( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Ternary: if gpnpu is enabled, use MlasQLinearAddFixedPoint here instead of the original MlasQLinearAdd, which is in the else clause.