From c9fc8a95f5e3c78b293ed7981bcaa5fdaec4d52c Mon Sep 17 00:00:00 2001 From: mojave2 Date: Sun, 28 Sep 2025 16:51:14 +0800 Subject: [PATCH] add mla_preprocess kernel Signed-off-by: mojave2 --- .pre-commit-config.yaml | 8 +- CMakeLists.txt | 6 + csrc/mla_preprocess/op_host/mla_preprocess.h | 698 ++++ .../op_host/tiling/mla_preprocess_tiling.h | 95 + csrc/mla_preprocess/op_kernel/kernel/common.h | 25 + .../op_kernel/kernel/common_func.h | 121 + .../op_kernel/kernel/hardware.h | 36 + .../op_kernel/kernel/iterator.h | 92 + .../kernel/iterators/gm_to_l1_iterator.inc | 162 + .../kernel/iterators/gm_to_ub_iterator.inc | 89 + .../kernel/iterators/l0c_to_gm_iterator.inc | 228 ++ .../kernel/iterators/l0c_to_l1_iterator.inc | 42 + .../kernel/iterators/l0c_to_ub_iterator.inc | 71 + .../kernel/iterators/l1_to_bt_iterator.inc | 39 + .../kernel/iterators/l1_to_fb_iterator.inc | 36 + .../kernel/iterators/l1_to_l0_iterator.inc | 310 ++ .../kernel/iterators/l1_to_ub_iterator.inc | 44 + .../op_kernel/kernel/kernel_utils.h | 395 +++ csrc/mla_preprocess/op_kernel/kernel/layout.h | 18 + csrc/mla_preprocess/op_kernel/kernel/mem.h | 82 + csrc/mla_preprocess/op_kernel/kernel/mma.h | 67 + .../mla_preprocess/op_kernel/kernel/set_fpc.h | 38 + csrc/mla_preprocess/op_kernel/kernel/simd.h | 274 ++ csrc/mla_preprocess/op_kernel/kernel/utils.h | 69 + .../mla_preprocess/op_kernel/mla_preprocess.h | 114 + .../op_kernel/mla_preprocess_kernel.cpp | 297 ++ .../op_kernel/mla_preprocess_mix_bf16.hpp | 2918 +++++++++++++++++ .../op_kernel/mla_preprocess_mix_fp16.hpp | 2508 ++++++++++++++ csrc/ops.h | 36 + csrc/torch_binding.cpp | 91 + csrc/torch_binding_meta.cpp | 38 +- 31 files changed, 9044 insertions(+), 3 deletions(-) create mode 100644 csrc/mla_preprocess/op_host/mla_preprocess.h create mode 100644 csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/common.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/common_func.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/hardware.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterator.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_l1_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_ub_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_gm_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_l1_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_ub_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_bt_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_fb_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_l0_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_ub_iterator.inc create mode 100644 csrc/mla_preprocess/op_kernel/kernel/kernel_utils.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/layout.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/mem.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/mma.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/set_fpc.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/simd.h create mode 100644 csrc/mla_preprocess/op_kernel/kernel/utils.h create mode 100644 csrc/mla_preprocess/op_kernel/mla_preprocess.h create mode 100644 csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp create mode 100644 
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp create mode 100644 csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bdfdb4c141..975303554a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,8 +12,8 @@ repos: - id: codespell args: [ --toml, pyproject.toml, - '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', - '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn' + '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', + '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND' ] additional_dependencies: - tomli @@ -35,6 +35,10 @@ repos: rev: v1.32.0 hooks: - id: typos + args: [ + "--force-exclude", + "--exclude", "csrc/mla_preprocess/**" + ] - repo: https://github.com/PyCQA/isort rev: 6.0.1 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d06c75f29..b64611df73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,11 +44,13 @@ else() endif() include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) + file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp) ascendc_library(vllm_ascend_kernels SHARED ${KERNEL_FILES} + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp ) message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") @@ -90,7 +92,11 @@ target_link_libraries( libtorch_npu.so vllm_ascend_kernels ascendcl + tiling_api + register platform + ascendalog + dl ) target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib") diff --git a/csrc/mla_preprocess/op_host/mla_preprocess.h b/csrc/mla_preprocess/op_host/mla_preprocess.h new file mode 100644 index 0000000000..66ad8c33b3 --- /dev/null +++ b/csrc/mla_preprocess/op_host/mla_preprocess.h @@ -0,0 +1,698 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost.git +// https://gitee.com/ascend/op-plugin.git +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+// + +#include +#include +#include +#include +#include "acl/acl.h" +// #include "defines.h" +// #include "torch_helper.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/mla_preprocess_tiling.h" + +// #include "aclrtlaunch_mla_preprocess.h" + +// namespace sglang { +namespace mlapo { + +constexpr uint32_t DIM_2 = 2; + +constexpr uint32_t AXES_ALIGN_SIZE = 512; +constexpr uint32_t BASE_BLOCK_STEP = 2; +constexpr uint32_t CONST_16 = 16; +constexpr uint32_t CONST_32 = 32; +constexpr uint32_t CONST_128 = 128; +constexpr uint32_t CONST_256 = 256; +constexpr uint32_t CONST_512 = 512; +constexpr uint32_t L1_BUFFER_SIZE = 524288; +constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 262144; +constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN = 131072; +constexpr uint32_t L1_SCALE_SIZE = 4096; +constexpr uint32_t L1_BIAS_SIZE = 2048; +constexpr uint32_t L0C_SIZE = 128 * 1024; +constexpr uint32_t CONCAT_SIZE = 512; + +constexpr uint32_t HIDDEN_STRATE = 7168; +constexpr uint32_t HIDDEN_STRATE_ROPE = 192; +constexpr uint32_t HIDDEN_STRATE_MM = 2112; +constexpr uint32_t HIDDEN_STRATE_RMS = 1536; +constexpr uint32_t UB_SIZE = 196352; +constexpr uint32_t HEADDIM = 64; +constexpr uint32_t FP32_REPEAT_MASK = 64; +constexpr uint32_t FP16_REPEAT_MASK = 128; + +constexpr int32_t NUM1 = 1; +constexpr int32_t NUM2 = 2; +constexpr int32_t NUM3 = 3; +constexpr int32_t NUM4 = 4; +constexpr int32_t NUM8 = 8; +constexpr uint32_t INDEX_WDQKV = 5; +constexpr uint32_t INDEX_WUQ = 18; +constexpr uint32_t INDEX_WUK = 20; + +constexpr uint32_t MAX_SUPPORT_TOKEN_NUMS = 1024; + +inline uint32_t CeilDiv(const uint32_t dividend, const uint32_t divisor) +{ + if (divisor == 0) { + return UINT32_MAX; + } + return (dividend + divisor - 1) / divisor; +} + +inline uint32_t RoundUp(const uint32_t val, const uint32_t align = 16) +{ + if (align == 0) { + return 0; + } + return (val + align - 1) / align * align; +} + +inline uint32_t RoundDown(const uint32_t val, const uint32_t align = 16) +{ + if (align == 0) { + return 0; + } + return val / align * align; +} + +template +inline T Max(const T a, const T b) +{ + return a > b ? a : b; +} + +template +inline T Min(const T a, const T b) +{ + return a < b ? a : b; +} + +struct MlaPreprocess { + enum class QuantMode : int32_t { + PER_TENSOR_ASYMM_QUANT = 0, + PER_TOKEN_SYMM_QUANT, + PER_TOKEN_ASYMM_QUANT, + NO_QUANT + }; +}; +using QuantMode = MlaPreprocess::QuantMode; + +struct PlatformInfo { + uint32_t coreNum; + uint32_t coreNumAic; + uint32_t coreNumAiv; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l2Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; +}; + +struct OpParam { + uint32_t N; + uint32_t headNum; + int32_t cacheMode; + QuantMode quantMode; + caffe2::TypeMeta inDtype; +}; + +class PpMatmulTilingApi +{ +public: + PpMatmulTilingApi(struct PlatformInfo &platformInfo, uint32_t numBatch, uint32_t m, uint32_t k, uint32_t n, + bool transA, bool transB, bool enDequant, bool deqOnTheFly) + : platformInfo_(platformInfo), + numBatch_(numBatch), + m_(m), + k_(k), + n_(n), + transA_(transA), + transB_(transB), + enDequant_(enDequant), + deqOnTheFly_(deqOnTheFly) + { + inDataSize_ = enDequant ? 
sizeof(uint8_t) : sizeof(uint16_t); + } + void GetTilingData(PpMatmulTilingData &tiling); + +private: + void GetTileSize(); + float GetCost(const uint32_t m0, const uint32_t n0); + void UpdateTileSize(const uint32_t m0, const uint32_t n0); + void Swizzle(); + uint32_t ComputeL1AbSize(); + uint32_t ComputeK0ForABpingpong(uint32_t l1AbSize); + bool IsLoadAllAmat(uint32_t l1AbSize); + uint32_t ComputeK0ForOnlyBpingpong(uint32_t l1AbSize); + +private: + uint32_t numBatch_{0}; + uint32_t m_{0}; + uint32_t k_{0}; + uint32_t n_{0}; + uint32_t m0_{0}; + uint32_t k0_{0}; + uint32_t n0_{0}; + uint32_t mLoop_{0}; + uint32_t kLoop_{0}; + uint32_t nLoop_{0}; + uint32_t coreLoop_{0}; + uint32_t swizzleCount_{0}; + uint32_t blockDim_{0}; + uint32_t swizzleDirect_{0}; + uint32_t inDataSize_{0}; + uint32_t b0matPingPongBufferLen_{L1_PINGPONG_BUFFER_LEN}; + bool transA_{false}; + bool transB_{false}; + bool enDequant_{false}; + bool enShuffleK_{false}; + bool enLoadAllAmat_{false}; + bool deqOnTheFly_{false}; + + struct PlatformInfo platformInfo_; +}; + +void PpMatmulTilingApi::GetTilingData(PpMatmulTilingData &tiling) +{ + GetTileSize(); + tiling.numBatch = numBatch_; + tiling.m = m_; + tiling.k = k_; + tiling.n = n_; + tiling.m0 = m0_; + tiling.k0 = k0_; + tiling.n0 = n0_; + tiling.mLoop = mLoop_; + tiling.kLoop = kLoop_; + tiling.nLoop = nLoop_; + tiling.coreLoop = coreLoop_; + tiling.swizzleCount = swizzleCount_; + tiling.swizzleDirect = swizzleDirect_; + tiling.enShuffleK = static_cast(enShuffleK_); + tiling.blockDim = blockDim_; + tiling.enLoadAllAmat = static_cast(enLoadAllAmat_); + tiling.b0matPingPongBufferLen = b0matPingPongBufferLen_; +} + +void PpMatmulTilingApi::GetTileSize() +{ + bool priFlag = !(m_ < n_); + uint32_t roundBase = pow(2, ceil(log(CeilDiv(priFlag ? n_ : m_, CONST_16)))) * CONST_16; + uint32_t priAxes = RoundUp(priFlag ? m_ : n_, CONST_16); + uint32_t subAxes = RoundUp(priFlag ? n_ : m_, roundBase); + float minCost = __FLT_MAX__; + uint32_t maxAxes0 = AXES_ALIGN_SIZE; + uint32_t maxPriAxes0 = Min(maxAxes0, priAxes); + uint32_t maxSubAxes0 = Min(maxAxes0, subAxes); + for (uint32_t priAxes0 = CONST_16; priAxes0 <= maxPriAxes0; priAxes0 *= BASE_BLOCK_STEP) { + for (uint32_t subAxes0 = CONST_16; subAxes0 <= maxSubAxes0; subAxes0 *= BASE_BLOCK_STEP) { + if (priAxes0 * subAxes0 * sizeof(float) > platformInfo_.l0cSize) { + continue; + } + uint32_t newM0 = priFlag ? priAxes0 : subAxes0; + uint32_t newN0 = priFlag ? subAxes0 : priAxes0; + if (newN0 > CONST_256 && enDequant_) { + continue; + } + float cost = GetCost(newM0, newN0); + if (cost < minCost) { + minCost = cost; + UpdateTileSize(newM0, newN0); + } + } + } + + Swizzle(); + + uint32_t l1AbSize = ComputeL1AbSize(); + k0_ = ComputeK0ForABpingpong(l1AbSize); + kLoop_ = CeilDiv(k_, k0_); +} + +uint32_t PpMatmulTilingApi::ComputeK0ForOnlyBpingpong(uint32_t l1AbSize) +{ + enLoadAllAmat_ = true; + b0matPingPongBufferLen_ = static_cast( + static_cast((l1AbSize - RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) / DIM_2)); + uint32_t k0MaxB0 = + static_cast(static_cast(b0matPingPongBufferLen_ / (RoundUp(n0_, CONST_16) * inDataSize_))); + uint32_t k0B0 = k0MaxB0 < CONST_512 ? RoundDown(k0MaxB0, CONST_32) : RoundDown(k0MaxB0, CONST_512); + return k0B0 > CONST_512 ? 
RoundDown(k0B0, CONST_512) : k0B0; +} + +bool PpMatmulTilingApi::IsLoadAllAmat(uint32_t l1AbSize) +{ + return (coreLoop_ > blockDim_) && enDequant_ && (kLoop_ > 1) && + (l1AbSize > RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) && (mLoop_ == 1); +} + +uint32_t PpMatmulTilingApi::ComputeK0ForABpingpong(uint32_t l1AbSize) +{ + uint32_t k0Max = static_cast(static_cast(l1AbSize / DIM_2) / ((m0_ + n0_) * inDataSize_)); + uint32_t tmpK0; + if (enDequant_) { + tmpK0 = k0Max < CONST_512 ? RoundDown(k0Max, CONST_32) : RoundDown(k0Max, CONST_512); + } else { + tmpK0 = k0Max < CONST_256 ? RoundDown(k0Max, CONST_16) : RoundDown(k0Max, CONST_256); + } + if (tmpK0 > CONST_512) { + tmpK0 = RoundDown(tmpK0, CONST_512); + } + return tmpK0; +} + +uint32_t PpMatmulTilingApi::ComputeL1AbSize() +{ + if (enDequant_ && deqOnTheFly_) { + return L1_BUFFER_SIZE; + } + return enDequant_ ? (L1_BUFFER_SIZE - L1_BIAS_SIZE - L1_SCALE_SIZE) : L1_BUFFER_SIZE; +} + +float PpMatmulTilingApi::GetCost(const uint32_t m0, const uint32_t n0) +{ + float aCoef = 1.0; + float bCoef = 1.0; + float bwCoef = 5.0; + uint32_t mLoop = CeilDiv(m_, m0); + uint32_t nLoop = CeilDiv(n_, n0); + if (mLoop == 0 || nLoop == 0) { + return __FLT_MAX__; + } + uint32_t rqdNumCore = numBatch_ * mLoop * nLoop; + uint32_t blockDim = Min(rqdNumCore, platformInfo_.coreNumAic); + uint32_t mOnce = blockDim < nLoop ? m0 : blockDim / nLoop * m0; + uint32_t nOnce = blockDim < nLoop ? platformInfo_.coreNumAic * n0 : n_; + if (mOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) { + aCoef = bwCoef; + } + if (nOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) { + bCoef = bwCoef; + } + if (transA_ && m0 % CONST_256 == 0) { + aCoef *= NUM2; + } + if (!transB_ && n0 % CONST_256 == 0) { + bCoef *= NUM2; + } + return 1 / (aCoef * static_cast(n0)) + 1 / (bCoef * static_cast(m0)); +} + +void PpMatmulTilingApi::UpdateTileSize(const uint32_t m0, const uint32_t n0) +{ + m0_ = m0; + n0_ = n0; + mLoop_ = CeilDiv(m_, m0_); + nLoop_ = CeilDiv(n_, n0_); + coreLoop_ = numBatch_ * mLoop_ * nLoop_; + const uint32_t maxNumCubeCore = platformInfo_.coreNumAic; + if (mLoop_ == 1 && transB_ && coreLoop_ % maxNumCubeCore < maxNumCubeCore / NUM4 * NUM3) { + uint32_t tmpM0 = RoundUp(m_, CONST_16); + uint32_t maxN0 = L0C_SIZE / (tmpM0 * sizeof(float)); + if (enDequant_) { + maxN0 = maxN0 < CONST_256 ? 
maxN0 : CONST_256; + } + uint32_t x = CeilDiv(n_, maxNumCubeCore); + uint32_t y = CeilDiv(x, maxN0); + uint32_t tmpN0 = RoundUp(CeilDiv(x, y), CONST_16); + uint32_t rqdL0cSize = tmpM0 * tmpN0 * sizeof(float); + if (rqdL0cSize < L0C_SIZE && (tmpM0 + tmpN0) * CONST_256 * inDataSize_ < L1_BUFFER_SIZE) { + m0_ = tmpM0; + n0_ = tmpN0; + nLoop_ = CeilDiv(n_, n0_); + coreLoop_ = numBatch_ * nLoop_; + } + } + blockDim_ = Min(coreLoop_, maxNumCubeCore); +} + +void PpMatmulTilingApi::Swizzle() +{ + float minCost = m_ * k_ + k_ * n_; + for (uint32_t i = 1; i <= blockDim_; ++i) { + int c = static_cast((blockDim_ + i - 1) / i); + float cost; + // B0 + A < A0 + B + if (i * n0_ + m_ < m0_ * c + n_) { + swizzleDirect_ = 1; // Nz + cost = n0_ * i + m0_ * c; + if (cost <= minCost) { + minCost = cost; + swizzleCount_ = i; + } + } else { + swizzleDirect_ = 0; // Zn + cost = m0_ * i + n0_ * c; + if (cost < minCost) { + minCost = cost; + swizzleCount_ = i; + } + } + } +} + +class MlaPreprocessTiling +{ +public: + MlaPreprocessTiling(struct PlatformInfo &platformInfo, struct OpParam &opParam, MlaTilingData *tilingData) + { + this->tilingData = tilingData; + this->platformInfo = platformInfo; + this->opParam = opParam; + } + void Init(); + + void RmsNormQuantTiling(); + void RopeConcatTiling(); + void EinSumQuantTiling(); + + void SetTilingKey(); + void SetMlapoWorkSpace(); + +private: + MlaTilingData *tilingData; + struct PlatformInfo platformInfo; + struct OpParam opParam; +}; + +void MlaPreprocessTiling::RmsNormQuantTiling() +{ + tilingData->rmsNumCore1 = platformInfo.coreNumAiv; + tilingData->rmsNumCol1 = HIDDEN_STRATE; + tilingData->rmsNumRow1 = opParam.N; + tilingData->rmsQuantMin1 = -CONST_128; + tilingData->rmsNumCore2 = platformInfo.coreNumAiv; + tilingData->rmsNumCol2 = HIDDEN_STRATE_MM; + tilingData->rmsNumRow2 = opParam.N; + tilingData->rmsQuantMin2 = -CONST_128; +} + +void MlaPreprocessTiling::RopeConcatTiling() +{ + uint32_t ntokens = opParam.N; + uint32_t hiddenSizeQ = HEADDIM * opParam.headNum; + uint32_t headDim = HEADDIM; + uint32_t headNumQ = hiddenSizeQ / headDim; + uint32_t concatSize = CONCAT_SIZE; + uint32_t maxCore = platformInfo.coreNumAiv; + uint32_t maxUbSize = platformInfo.ubSize; + + uint32_t allHeadNum = ntokens * headNumQ; + + uint32_t tempCore = (allHeadNum + maxCore - 1) / maxCore; + uint32_t realCore = (allHeadNum + tempCore - 1) / tempCore; // Actual number of the core for operation + uint32_t nlCoreRun = (allHeadNum + realCore - 1) / realCore; // The number of heads in the front core + uint32_t lCoreRun = allHeadNum - (realCore - 1) * nlCoreRun; // The number of heads in the tail core + + uint32_t dataTypeSize = 2; + + // Calculate how many lines can be moved at a time. 
q 4+2, reverseq 4, neg 4, sin 4+2, cos 4+2 + concat 2 + uint32_t allSize = + headDim * (3 * (4 + dataTypeSize) + 2 * 4) + concatSize * dataTypeSize; // lift precision calculation of ROPE + uint32_t maxNPerLoopForUb = maxUbSize / allSize; // the maximum number of rows at a time for UB + uint32_t preCoreLoopTime = (nlCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // Number of cycles of front core + uint32_t preCoreLoopNLast = + nlCoreRun - + (preCoreLoopTime - 1) * maxNPerLoopForUb; // rows of data processed in the last batch of the front core + uint32_t lastCoreLoopTime = (lCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // Number of cycles of tail core + uint32_t lastCoreLoopNLast = + lCoreRun - + (lastCoreLoopTime - 1) * maxNPerLoopForUb; // rows of data processed in the last batch of the tail core + + tilingData->hiddenSizeQ = hiddenSizeQ; + tilingData->headNumQ = headNumQ; + tilingData->headDim = headDim; + tilingData->concatSize = concatSize; + tilingData->rotaryCoeff = NUM2; + tilingData->ntokens = ntokens; + tilingData->realCore = realCore; + tilingData->nlCoreRun = nlCoreRun; + tilingData->lCoreRun = nlCoreRun; + tilingData->maxNPerLoopForUb = maxNPerLoopForUb; + tilingData->preCoreLoopTime = preCoreLoopTime; + tilingData->preCoreLoopNLast = preCoreLoopNLast; + tilingData->lastCoreLoopTime = lastCoreLoopTime; + tilingData->lastCoreLoopNLast = lastCoreLoopNLast; +} + +void MlaPreprocessTiling::EinSumQuantTiling() +{ + uint32_t aivCore = platformInfo.coreNumAiv; + uint32_t ubSize = UB_SIZE - 1024; + + // input shape + uint32_t esqBatch = opParam.N; // tokenNum + uint32_t esqHeadNum = opParam.headNum; // headNum + uint32_t esqColNum = AXES_ALIGN_SIZE; // 512 + + // split core + uint32_t esqFrontCore = esqBatch % aivCore; + uint32_t esqTailCore = aivCore - esqFrontCore; + uint32_t esqFrontCoreBatch = CeilDiv(esqBatch, aivCore); + uint32_t esqTailCoreBatch = esqBatch / aivCore; + + // split ub --> calc H' <-- The number of rows handled in a UB cycle. + uint32_t splitFactor = 0; + uint32_t esqHeadPerLoop = 0; // The number of head rows per UB calculation + uint32_t repeatMask = 0; + + if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + // Move scales in at once, broadcast, and cache them all H * 32bytes + uint32_t scaleUb = RoundUp(esqHeadNum) * CONST_32; + // bf16 input [H', colNum](f16 + fp32 + int8), ub reuse + splitFactor = esqColNum * (sizeof(uint16_t) + sizeof(float) + sizeof(uint8_t)); + splitFactor *= NUM2; + esqHeadPerLoop = (ubSize - scaleUb) / splitFactor; // 26 + repeatMask = FP32_REPEAT_MASK; + } else { + // fp16 input [H', colNum](fp16*2 + int8) + [H', 1](fp16) + [H', 16](fp16) + splitFactor = + esqColNum * (NUM2 * sizeof(uint16_t) + sizeof(uint8_t)) + sizeof(uint16_t) + (CONST_16 * sizeof(uint16_t)); + esqHeadPerLoop = ubSize / splitFactor; + repeatMask = FP16_REPEAT_MASK; + esqHeadPerLoop = RoundDown(esqHeadPerLoop); + } + uint32_t esqUbHeadLoop = esqHeadNum / esqHeadPerLoop; // UB complete cycles + uint32_t esqHeadTail = esqHeadNum % esqHeadPerLoop; // number of head rows processed in the final UB iteration + uint32_t esqColLoop = esqColNum / repeatMask; // number of column iterations per row + uint32_t esqColTail = + esqColNum % repeatMask; // leftover columns when colNum is not 64/128 aligned, handled last
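+    // Worked example of the column split above, using the constants defined in this file:
+    // esqColNum = AXES_ALIGN_SIZE = 512. On the bf16 / per-token path repeatMask = FP32_REPEAT_MASK = 64,
+    // so esqColLoop = 512 / 64 = 8 and esqColTail = 0; on the fp16 path repeatMask = FP16_REPEAT_MASK = 128,
+    // so esqColLoop = 512 / 128 = 4 and esqColTail = 0.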
+ + tilingData->esqFrontCore = esqFrontCore; + tilingData->esqTailCore = esqTailCore; + tilingData->esqFrontCoreBatch = esqFrontCoreBatch; + tilingData->esqTailCoreBatch = esqTailCoreBatch; + tilingData->esqHeadNum = esqHeadNum; + tilingData->esqColNum = esqColNum; + tilingData->esqUbHeadLoop = esqUbHeadLoop; + tilingData->esqHeadPerLoop = esqHeadPerLoop; + tilingData->esqHeadTail = esqHeadTail; + tilingData->esqColLoop = esqColLoop; + tilingData->esqColTail = esqColTail; +} + +void MlaPreprocessTiling::SetMlapoWorkSpace() +{ + uint64_t s1wsFactor = + static_cast(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t), + opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t)) + : HIDDEN_STRATE * sizeof(int8_t)); + uint64_t workSizeS1 = s1wsFactor; + uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t); + uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t); + uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t); + + uint64_t maxWorkspaceSize = workSizeS1; + maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2); + maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS3); + maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS4); + maxWorkspaceSize *= static_cast(opParam.N); + + uint64_t pertokenWorkspace = static_cast(opParam.N) * sizeof(float) * 2; + + uint64_t userWorkspaceSize; + if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace; + } else { + userWorkspaceSize = 3 * maxWorkspaceSize; + } + + tilingData->userWorkspaceSize = userWorkspaceSize; + tilingData->s1Offset = 0; + tilingData->s2Offset = tilingData->s1Offset + maxWorkspaceSize; + tilingData->s3Offset = tilingData->s2Offset + maxWorkspaceSize; + tilingData->s4Offset = tilingData->s3Offset + maxWorkspaceSize; + tilingData->s5Offset = tilingData->s4Offset + maxWorkspaceSize; +} + +void MlaPreprocessTiling::SetTilingKey() +{ + uint64_t tilingKey = (static_cast(opParam.inDtype == at::kBFloat16)) << 8; + + tilingKey |= static_cast(opParam.cacheMode); + tilingKey |= (static_cast(opParam.quantMode) << 3); + + tilingData->tilingKey = tilingKey; +} + +void MlaPreprocessTiling::Init() +{ + tilingData->numCore = platformInfo.coreNumAic; + tilingData->n = opParam.N; + + bool deqOnTheFly = false; + if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + deqOnTheFly = true; + } + + PpMatmulTilingApi mm1TilingApi(platformInfo, + 1, // numBatch + opParam.N, // m + HIDDEN_STRATE, // k + HIDDEN_STRATE_MM, // n + false, // transA + true, // transB + true, // enDequant + deqOnTheFly); // in bf16.cce? + mm1TilingApi.GetTilingData(tilingData->mm1); + + PpMatmulTilingApi mm2TilingApi(platformInfo, + 1, // numBatch + opParam.N, // m + HIDDEN_STRATE_RMS, // k + opParam.headNum * HIDDEN_STRATE_ROPE, // n + false, // transA + true, // transB + true, // enDequant + deqOnTheFly); // in bf16.cce? + mm2TilingApi.GetTilingData(tilingData->mm2); + + PpMatmulTilingApi mm3TilingApi(platformInfo, + opParam.headNum, // numBatch + opParam.N, // m + CONST_128, // k + CONCAT_SIZE, // n + false, // transA + false, // transB + false, // enDequant + deqOnTheFly); // in bf16.cce? 
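+    // Shapes of the three matmuls tiled above, read off the constructor arguments:
+    // mm1: [N, 7168] x [2112, 7168]^T (int8 dequant); mm2: [N, 1536] x [headNum * 192, 1536]^T (int8 dequant);
+    // mm3: headNum batches of [N, 128] x [128, 512] (no dequant).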
+ mm3TilingApi.GetTilingData(tilingData->mm3); + + RmsNormQuantTiling(); + RopeConcatTiling(); + EinSumQuantTiling(); + + SetMlapoWorkSpace(); + SetTilingKey(); + + return; +} + +std::unordered_map cache_mode_map = { + {"krope_ctkv", 1}, {"int8_nzcache", 2}, {"nzcache", 3}}; + +std::unordered_map quant_mode_map = { + {"per_tensor_quant_asymm", 0}, + {"per_token_quant_symm", 1}, +}; + +template +inline int get_op_mode(const MapType &mode_map, c10::optional mode_opt, c10::string_view default_mode, + const char *mode_name) +{ + c10::string_view mode_str = mode_opt.value_or(default_mode); + auto it = mode_map.find(mode_str); + TORCH_CHECK(it != mode_map.end(), "Unsupported ", mode_name, " value: '", mode_str, "'"); + return it->second; +} + +// std::tuple mla_preprocess( +// const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv, +// const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq, +// const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, +// const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, +// const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0, +// const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1, +// const c10::optional &ctkv_scale, const c10::optional &q_nope_scale, +// c10::optional cache_mode, c10::optional quant_mode, at::Tensor &q_out0, +// at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1) +std::tuple mla_preprocess_tiling( + const at::Tensor &hiddenState, + const at::Tensor &wuk, + c10::optional cache_mode, + c10::optional quant_mode +) +{ + auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode"); + auto quantMode = get_op_mode(quant_mode_map, quant_mode, "per_token_quant_symm", "quant_mode"); + + platform_ascendc::PlatformAscendC *platformAscendC = platform_ascendc::PlatformAscendCManager::GetInstance(); + + struct PlatformInfo platformInfo; + platformInfo.coreNum = platformAscendC->GetCoreNum(); + platformInfo.coreNumAic = platformAscendC->GetCoreNumAic(); + platformInfo.coreNumAiv = platformAscendC->GetCoreNumAiv(); + platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformInfo.ubSize); + platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L1, platformInfo.l1Size); + platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L2, platformInfo.l2Size); + platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, platformInfo.l0aSize); + platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, platformInfo.l0bSize); + platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, platformInfo.l0cSize); + + int32_t N = hiddenState.sizes()[0]; + int32_t headNum = wuk.sizes()[0]; + + OpParam opParam; + opParam.N = N; + opParam.headNum = headNum; + opParam.cacheMode = static_cast(cacheMode); + opParam.quantMode = static_cast(quantMode); + opParam.inDtype = hiddenState.options().dtype(); + + MlaTilingData tilingData; + MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData); + + mlaTiling.Init(); + uint32_t blockDim = platformInfo.coreNumAic; + + // workspace + uint64_t system_workspace_size = static_cast(platformAscendC->GetLibApiWorkSpaceSize()); + uint64_t workspace_size = system_workspace_size + tilingData.userWorkspaceSize; + auto options = 
at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device()); + auto workspace_tensor = at::empty({static_cast(workspace_size)}, options); + + // tiling + int32_t bIndex = N - 1; + uint32_t tilingSize = sizeof(MlaTilingData); + static auto global_tiling_data = + at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS}, + at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device())); + if (bIndex >= 0 && bIndex < MAX_SUPPORT_TOKEN_NUMS) { + aclrtMemcpy(global_tiling_data.data_ptr() + (tilingSize * bIndex), tilingSize, &tilingData, tilingSize, + ACL_MEMCPY_HOST_TO_DEVICE); + } else { + // Handle the case where bIndex is out of range + TORCH_CHECK(false, "bIndex is out of range: ", bIndex); + } + at::Tensor tiling = at::from_blob( + global_tiling_data.data_ptr() + (tilingSize * bIndex), + tilingSize, + at::kByte); + + return std::make_tuple(workspace_tensor, tiling, blockDim); +} + +} // namespace npu_kernel diff --git a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h new file mode 100644 index 0000000000..aab1f3a7a9 --- /dev/null +++ b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h @@ -0,0 +1,95 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+// + +#ifndef MLAPREPROCESS_TILING_H +#define MLAPREPROCESS_TILING_H + +#include + +struct PpMatmulTilingData { + uint32_t numBatch{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + uint32_t mLoop{0}; + uint32_t kLoop{0}; + uint32_t nLoop{0}; + uint32_t coreLoop{0}; + uint32_t swizzleCount{0}; + uint32_t swizzleDirect{0}; + uint32_t enShuffleK{0}; + uint32_t blockDim{0}; + uint32_t enLoadAllAmat{0}; + uint32_t b0matPingPongBufferLen{0}; +}; + +struct MlaTilingData { + uint32_t tilingKey{0}; + uint64_t userWorkspaceSize{0}; + uint64_t s1Offset{0}; + uint64_t s2Offset{0}; + uint64_t s3Offset{0}; + uint64_t s4Offset{0}; + uint64_t s5Offset{0}; + + uint32_t numCore{0}; + uint32_t n{0}; + uint32_t perTaskNum{0}; + uint32_t resTaskNum{0}; + + PpMatmulTilingData mm1; + PpMatmulTilingData mm2; + PpMatmulTilingData mm3; + // rms1 + uint32_t rmsNumCore1{0}; + uint32_t rmsNumCol1{0}; + uint32_t rmsNumRow1{0}; + uint32_t rmsQuantMin1{0}; + // rms2 + uint32_t rmsNumCore2{0}; + uint32_t rmsNumCol2{0}; + uint32_t rmsNumRow2{0}; + uint32_t rmsQuantMin2{0}; + + uint32_t hiddenSizeQ{0}; + uint32_t headNumQ{0}; + uint32_t headDim{0}; + uint32_t concatSize{0}; + uint32_t rotaryCoeff{0}; + uint32_t ntokens{0}; + uint32_t realCore{0}; + uint32_t nlCoreRun{0}; + uint32_t lCoreRun{0}; + uint32_t maxNPerLoopForUb{0}; + uint32_t preCoreLoopTime{0}; + uint32_t preCoreLoopNLast{0}; + uint32_t lastCoreLoopTime{0}; + uint32_t lastCoreLoopNLast{0}; + + // EinSumQuant + uint32_t esqFrontCore{0}; + uint32_t esqTailCore{0}; + uint32_t esqFrontCoreBatch{0}; + uint32_t esqTailCoreBatch{0}; + uint32_t esqHeadNum{0}; + uint32_t esqColNum{0}; + uint32_t esqUbHeadLoop{0}; + uint32_t esqHeadPerLoop{0}; + uint32_t esqHeadTail{0}; + uint32_t esqColLoop{0}; + uint32_t esqColTail{0}; +}; + +#endif // MLAPREPROCESS_TILING_H diff --git a/csrc/mla_preprocess/op_kernel/kernel/common.h b/csrc/mla_preprocess/op_kernel/kernel/common.h new file mode 100644 index 0000000000..d379e9b00e --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/common.h @@ -0,0 +1,25 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef INCLUDE_COMMON_H +#define INCLUDE_COMMON_H + +#define CONST_2 2 + +#define SET_FLAG(trigger, waiter, e) AscendC::SetFlag((e)) +#define WAIT_FLAG(trigger, waiter, e) AscendC::WaitFlag((e)) +#define PIPE_BARRIER(pipe) AscendC::PipeBarrier() + +#ifndef __force_inline__ +#define __force_inline__ inline __attribute__((always_inline)) +#endif + +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/common_func.h b/csrc/mla_preprocess/op_kernel/kernel/common_func.h new file mode 100644 index 0000000000..683d1a1f81 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/common_func.h @@ -0,0 +1,121 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef INCLUDE_COMMON_FUNC_H +#define INCLUDE_COMMON_FUNC_H + +#include +#include + +#ifdef __CCE_KT_TEST__ +#include "stub_def.h" +#include "stub_fun.h" +#else +#include "kernel_macros.h" +#endif + +template +inline __aicore__ T RoundUp(const T val) +{ + static_assert(ALIGN != 0, "align must not be zero"); + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + T align = ALIGN; + if (val + align - 1 < val) { + return val; + } + return (val + align - 1) / align * align; +} + +template +inline __aicore__ T RoundUp(const T val, const T align) +{ + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + if (align == 0 || val + align - 1 < val) { + return val; + } + return (val + align - 1) / align * align; +} + +template +inline __aicore__ T CeilDiv(const T dividend) +{ + static_assert(DIVISOR != 0, "align must not be zero"); + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + T divisor = DIVISOR; + if (dividend + divisor - 1 < dividend) { + return dividend; + } + return (dividend + divisor - 1) / divisor; +} + +template +constexpr T T_MAX = std::numeric_limits::max(); + +template +inline __aicore__ T CeilDiv(const T dividend, const T divisor) +{ + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + if (divisor == 0 || dividend + divisor - 1 < dividend) { + return T_MAX; + } + return (dividend + divisor - 1) / divisor; +} + +template +__aicore__ inline T Min(const T lhs, const T rhs) +{ + return lhs < rhs ? lhs : rhs; +} + +template +__aicore__ __attribute__((always_inline)) inline uint32_t BlockSize() +{ + return 32 / sizeof(Dtype); +} + +template +__aicore__ __attribute__((always_inline)) inline uint32_t MatrixSize() +{ + return 512 / sizeof(Dtype); +} + +template +__aicore__ __attribute__((always_inline)) inline uint64_t BlockSizeRoundUp(uint64_t num) +{ + return (num + BlockSize() - 1) / BlockSize() * BlockSize(); +} + +template +__aicore__ __attribute__((always_inline)) inline uint64_t NumBlocksRoundUp(uint64_t num) +{ + return (num + BlockSize() - 1) / BlockSize(); +} + +template +__aicore__ __attribute__((always_inline)) inline uint64_t MatrixSizeRoundUp(uint64_t num) +{ + return (num + MatrixSize() - 1) / MatrixSize() * MatrixSize(); +} + +template +__aicore__ __attribute__((always_inline)) inline uint64_t NumMatrixsRoundUp(uint64_t num) +{ + return (num + MatrixSize() - 1) / MatrixSize(); +} + +template +__aicore__ __attribute__((always_inline)) inline uint64_t L0HalfSize() +{ + return 32 * 1024 / sizeof(Dtype); +} + +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/hardware.h b/csrc/mla_preprocess/op_kernel/kernel/hardware.h new file mode 100644 index 0000000000..1370710b4c --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/hardware.h @@ -0,0 +1,36 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef INCLUDE_HARDWARE_H +#define INCLUDE_HARDWARE_H + +enum class ArchType { ASCEND_V220, ASCEND_V200, ASCEND_M200 }; + +template +struct HardwareInfo { + static uint32_t const l2BW = 5; + static uint32_t const hbmBW = 1; + static uint32_t const supportMix = 0; + static uint32_t const l1Size = 512 * 1024; + static uint32_t const l0ASize = 64 * 1024; + static uint32_t const l0BSize = 64 * 1024; + static uint32_t const l0CSize = 128 * 1024; + static uint32_t const l2Size = 192 * 1024 * 1024; + static uint32_t const biasSize = 1024; + static uint32_t const fixBufSize = 7 * 1024; + static uint32_t const ubSize = 192 * 1024; + static uint32_t const fractalSize = 512; + static uint32_t const l1l0BlockSize = 32; + static uint32_t const btBlockSize = 64; + static uint32_t const fbBlockSize = 128; +}; + +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterator.h b/csrc/mla_preprocess/op_kernel/kernel/iterator.h new file mode 100644 index 0000000000..3e728930e7 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterator.h @@ -0,0 +1,92 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef INCLUDE_ITERTOR_H +#define INCLUDE_ITERTOR_H + +#include "common_func.h" +#include "hardware.h" +#include "kernel_operator.h" +#include "layout.h" +#include "mem.h" + +///////////////////////////////////////////////////// +// gm_to_l1 +///////////////////////////////////////////////////// +template +struct gm_to_l1 { + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual, + uint32_t dTileCeil, uint32_t dVal) {}; +}; + +///////////////////////////////////////////////////// +// l1_to_l0_a +///////////////////////////////////////////////////// +template +struct l1_to_l0_a { + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, uint32_t kPartCeil, uint32_t mSrcStride, uint32_t kSrcStride, + uint32_t mDstStride, uint32_t kDstStride) {}; +}; + +///////////////////////////////////////////////////// +// l1_to_l0_b +///////////////////////////////////////////////////// +template +struct l1_to_l0_b { + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, uint32_t kPartCeil, uint32_t nSrcStride, uint32_t kSrcStride, + uint32_t nDstStride, uint32_t kDstStride) {}; +}; + +///////////////////////////////////////////////////// +// l0c_to_gm +///////////////////////////////////////////////////// +template +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, uint32_t nTileActual, uint32_t mTileCeil, uint32_t nActual, + uint8_t unitFlag = 0) {}; +}; + +///////////////////////////////////////////////////// +// l0c_to_l1 +///////////////////////////////////////////////////// +template +struct l0c_to_l1 { + __aicore__ l0c_to_l1(AscendC::LocalTensor l1Tensor, AscendC::LocalTensor l0cTensor, + AscendC::LocalTensor deqTensor, uint32_t mTileActual, uint32_t nTileActual, + uint32_t mTileCeil, uint32_t nActual) {}; +}; + +template +struct l1_to_bt { + __aicore__ l1_to_bt(uint64_t dst, const AscendC::LocalTensor &src, uint16_t convControl, uint16_t nBurst, + uint16_t lenBurst, uint16_t srcGap, uint16_t dstGap) {}; +}; + +template +struct l1_to_fb { + __aicore__ l1_to_fb(AscendC::LocalTensor &dst, AscendC::LocalTensor &src, uint16_t burstNum, + uint16_t burstLen, uint16_t srcGap, uint16_t dstGap) {}; +}; + +#include "iterators/gm_to_l1_iterator.inc" +#include "iterators/gm_to_ub_iterator.inc" +#include "iterators/l0c_to_gm_iterator.inc" +#include "iterators/l0c_to_l1_iterator.inc" +#include "iterators/l0c_to_ub_iterator.inc" +#include "iterators/l1_to_bt_iterator.inc" +#include "iterators/l1_to_fb_iterator.inc" +#include "iterators/l1_to_l0_iterator.inc" +#include "iterators/l1_to_ub_iterator.inc" +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_l1_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_l1_iterator.inc new file mode 100644 index 0000000000..0d201642f8 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_l1_iterator.inc @@ -0,0 +1,162 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "../iterator.h" + +// Partial specialization for V220, ND_in, ND_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + AscendC::DataCopy(l1Tensor, // dst + gmTensor, // src + AscendC::DataCopyParams(1, // nBurst + CeilDiv(nTileActual * dTileActual), // lenBurst + 0, // srcGap + 0)); // dstGap + }; +}; + +// Partial specialization for NZ_in, NZ_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t STRIDE_LIMIT = 65536; + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + uint64_t srcStride = nVal - nTileCeil; + if (srcStride < STRIDE_LIMIT) { + AscendC::DataCopy(l1Tensor, // dst + gmTensor, // src + AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst + nTileCeil, // lenBurst + srcStride, // srcGap + 0)); // dstGap + } else { + for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; i++) { + uint64_t dstOffset = i * nTileCeil * BLOCK_SIZE; + uint64_t srcOffset = i * nVal * BLOCK_SIZE; + AscendC::DataCopy(l1Tensor[dstOffset], // dst + gmTensor[srcOffset], // src + AscendC::DataCopyParams(1, // nBurst + nTileCeil, // lenBurst + 0, // srcGap + 0)); // dstGap + } + } + }; +}; + +// Partial specialization for V220, ND_in, ND_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t STRIDE_LIMIT = 65536; + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + if (dVal < STRIDE_LIMIT) { + AscendC::DataCopy(l1Tensor, + gmTensor, + AscendC::Nd2NzParams(1, // ndNum + nTileActual, // nValue + dTileActual, // dValue + 0, // srcNdMatrixStride, unused + dVal, // srcDValue + nTileCeil, // dstNzC0Stride + 1, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } else { + for (uint32_t i = 0; i < nTileActual; i++) { + AscendC::DataCopy(l1Tensor[i * BLOCK_SIZE], + gmTensor[i * dVal], + AscendC::Nd2NzParams(1, // ndNum + 1, // nValue + dTileActual, // dValue + 0, // srcNdMatrixStride, unused + 0, // srcDValue + nTileCeil, // dstNzC0Stride + 0, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } + } + }; +}; + +// Partial specialization for V220, ND_in, NZ_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t STRIDE_LIMIT = 65536; + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + 
uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + if (dVal < STRIDE_LIMIT) { + AscendC::DataCopy(l1Tensor, + gmTensor, + AscendC::Nd2NzParams(1, // ndNum + nTileActual, // nValue + dTileActual, // dValue + 0, // srcNdMatrixStride, unused + dVal, // srcDValue + nTileCeil, // dstNzC0Stride + 1, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } else { + for (uint32_t i = 0; i < nTileActual; ++i) { + AscendC::DataCopy(l1Tensor, + gmTensor, + AscendC::Nd2NzParams(1, // ndNum + 1, // nValue + dTileActual, // dValue + 0, // srcNdMatrixStride, unused + 0, // srcDValue + nTileCeil, // dstNzC0Stride + 0, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } + } + }; +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_ub_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_ub_iterator.inc new file mode 100644 index 0000000000..9fdf176055 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/gm_to_ub_iterator.inc @@ -0,0 +1,89 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "../iterator.h" + +template struct gm_to_ub { + __aicore__ inline gm_to_ub(AscendC::LocalTensor dstTensor, AscendC::GlobalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride) + { + AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +template struct gm_to_ub_align { + __aicore__ inline gm_to_ub_align(AscendC::LocalTensor dstTensor, AscendC::GlobalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum, + uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap) + { + AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0), + AscendC::DataCopyPadExtParams(false, leftPaddingNum, rightPaddingNum, 0)); + }; +}; + +template struct ub_to_ub { + __aicore__ inline ub_to_ub(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride) + { + AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +template +struct ub_to_gm { + __aicore__ inline ub_to_gm(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride) + { + AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +template struct ub_to_gm { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ ub_to_gm(AscendC::GlobalTensor gmTensor, AscendC::LocalTensor ubTensor, + uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t 
dTileActual, + uint32_t dTileCeil, uint32_t dVal) + { + constexpr uint32_t STRIDE_LIMIT = 65536; + uint64_t dstStride = nVal - nTileCeil; + if (dstStride < STRIDE_LIMIT) { + AscendC::DataCopy(gmTensor, // dst + ubTensor, // src + AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst + nTileCeil, // lenBurst + 0, // srcGap + dstStride)); // dstGap + } else { + for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; ++i) { + uint64_t dstOffset = i * nVal * BLOCK_SIZE; + uint64_t srcOffset = i * nTileCeil * BLOCK_SIZE; + AscendC::DataCopy(gmTensor[dstOffset], // dst + ubTensor[srcOffset], // src + AscendC::DataCopyParams(1, // nBurst + nTileCeil, // lenBurst + 0, // srcGap + 0)); // dstGap + } + } + }; +}; + +template struct ub_to_gm_align { + __aicore__ inline ub_to_gm_align(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum, + uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap) + { + AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0)); + }; +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_gm_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_gm_iterator.inc new file mode 100644 index 0000000000..8ef8d93dc0 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_gm_iterator.inc @@ -0,0 +1,228 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "../iterator.h" + +constexpr uint32_t BLOCK_NUM = 16; +constexpr uint32_t BLOCK_SIZE_INT8 = 32; + +template <> +struct l0c_to_gm { + /** + * @brief Copy data from L0C buffer to global memory, partial specialized for + * + * @param gmTensor the destination tensor on global memory, which is stored in ND format. + * @param l0cTensor the source tensor on L0C buffer, which is stored in FRACTAL_NZ format. + * @param mTileActual the m-direction size of the matrix in L0C buffer. + * @param nTileActual the n-direction size of the matrix in L0C buffer. + * @param srcStride the source stride between the adjacent fractal matrix along n-direction in unit of C0_SIZE. + * @param dstStride the leading dimension of the destination matrix in unit of element. 
+ */ + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride, + uint8_t unitFlag = 0) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::F322F16; + intriParams.unitFlag = unitFlag; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::F322F16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride, + uint8_t unitFlag = 0) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::VDEQF16; + intriParams.unitFlag = unitFlag; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::VDEQF16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +#ifdef __DAV_C220_CUBE__ + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor<__bf16> gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride, + uint8_t unitFlag = 0) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::F322BF16; + intriParams.unitFlag = unitFlag; + AscendC::Fixpipe<__bf16, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::F322BF16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +#endif + + +// Partial specialization ND, float +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride, + uint8_t unitFlag = 0) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::NoQuant; + intriParams.unitFlag = unitFlag; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); 
+#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::NoQuant}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride, + uint8_t unitFlag = 0) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::F322F16; + intriParams.unitFlag = unitFlag; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride - (nTileActual * sizeof(half) / sizeof(float))); + intriParams.quantParams = {QuantMode_t::F322F16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride, + uint8_t unitFlag = 0) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::NoQuant; + intriParams.unitFlag = unitFlag; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::VDEQF16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_l1_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_l1_iterator.inc new file mode 100644 index 0000000000..ddda8f6dcc --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_l1_iterator.inc @@ -0,0 +1,42 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "../iterator.h" +///////////////////////////////////////////////////// +// l0c_to_l1 +///////////////////////////////////////////////////// + +// Partial specialization ZN, half, int32_t +template +struct l0c_to_l1 { + using ElementOut = half; + using ElementIn = int32_t; + __aicore__ l0c_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::LocalTensor l0cTensor, + AscendC::LocalTensor deqTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t mTileCeil, + uint32_t nActual) + { + constexpr uint32_t BLOCK_NUM = 16; + constexpr uint32_t BLOCK_SIZE = 32; + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE), + 0, + mTileCeil - static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE) * + sizeof(ElementOut) / sizeof(ElementIn)); + intriParams.nz2ndParams = {false, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::VDEQF16}; + AscendC::Fixpipe(l1Tensor, l0cTensor, deqTensor, intriParams); + }; +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_ub_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_ub_iterator.inc new file mode 100644 index 0000000000..129e7e59bf --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/l0c_to_ub_iterator.inc @@ -0,0 +1,71 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l0c_to_ub +///////////////////////////////////////////////////// + +// Partial specialization ZN, half, int32_t +template +struct l0c_to_ub { + __aicore__ l0c_to_ub(AscendC::LocalTensor ubTensor, + AscendC::LocalTensor l0cTensor, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcStride, + uint16_t dstStride) + { + constexpr auto mode = + MatrixMode ? 
AscendC::BlockMode::BLOCK_MODE_MATRIX : AscendC::BlockMode::BLOCK_MODE_VECTOR; + AscendC::DataCopy(ubTensor, + l0cTensor, + AscendC::DataCopyParams(nBurst, // count + lenBurst, // len + srcStride, // srcStrideIn + dstStride), // dstStrideIn + AscendC::DataCopyEnhancedParams(mode, // blockModeIn + AscendC::DeqScale::DEQ_NONE, // deqScaleIn + 0, // deqValueIn + 0, // sidStoreModeIn + false, // isReluIn + pad_t::PAD_NONE, // padModeIn + 0) // padValueIn + ); + }; +}; + +template +struct l0c_to_ub { + __aicore__ l0c_to_ub(AscendC::LocalTensor ubTensor, + AscendC::LocalTensor l0cTensor, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcStride, + uint16_t dstStride) + { + AscendC::DataCopy(ubTensor, + l0cTensor, + AscendC::DataCopyParams(nBurst, // count + lenBurst, // len + srcStride, // srcStrideIn + dstStride), // dstStrideIn + AscendC::DataCopyEnhancedParams(AscendC::BlockMode::BLOCK_MODE_MATRIX, // blockModeIn + AscendC::DeqScale::VDEQ16, // deqScaleIn + 0, // deqValueIn + 0, // sidStoreModeIn + false, // isReluIn + pad_t::PAD_NONE, // padModeIn + 0) // padValueIn + ); + }; +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_bt_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_bt_iterator.inc new file mode 100644 index 0000000000..161468b597 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_bt_iterator.inc @@ -0,0 +1,39 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_bt +///////////////////////////////////////////////////// + +// Partial specialization for V220 +template +struct l1_to_bt { + __aicore__ l1_to_bt(uint64_t dst, + const AscendC::LocalTensor &src, + uint16_t convControl, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcGap, + uint16_t dstGap) + { + AscendC::LocalTensor dstTensor; + dstTensor.InitBuffer(dst, nBurst * lenBurst); + dstTensor.address_.logicPos = static_cast(AscendC::TPosition::C2); + AscendC::DataCopy(dstTensor, + src, + AscendC::DataCopyParams(nBurst, // nBurst + lenBurst, // lenBurst + srcGap, // srcGap + dstGap)); // dstGap + } +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_fb_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_fb_iterator.inc new file mode 100644 index 0000000000..0e8c7af891 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_fb_iterator.inc @@ -0,0 +1,36 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_fb +///////////////////////////////////////////////////// + +// Partial specialization for V220 +template +struct l1_to_fb { + __aicore__ l1_to_fb(AscendC::LocalTensor &dst, + AscendC::LocalTensor &src, + uint16_t burstNum, + uint16_t burstLen, + uint16_t srcGap, + uint16_t dstGap) + { + dst.address_.logicPos = static_cast(AscendC::TPosition::C2PIPE2GM); + AscendC::DataCopy(dst, + src, + AscendC::DataCopyParams(burstNum, // nBurst + burstLen, // lenBurst + srcGap, // srcGap + dstGap)); // dstGap); + } +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_l0_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_l0_iterator.inc new file mode 100644 index 0000000000..4cd234f16e --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_l0_iterator.inc @@ -0,0 +1,310 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_l0_a +///////////////////////////////////////////////////// + +// Partial specialization for vector +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + AscendC::LoadData(l0Tensor, + l1Tensor, + AscendC::LoadData2dParams(0, // baseIdx + kPartCeil, // repeat + kSrcStride, // srcStride + 0, // sid + kDstStride, // dstStride + IsTransPose, // transpose + 0)); // addrCalMode + }; +}; + +// Partial specialization for no transpose, not vector +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < mTileCeil / BLOCK_NUM_PER_FRACTAL; i++) { + AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], // dst + l1Tensor[i * mSrcStride * FRACTAL_SIZE], // src + AscendC::LoadData2dParams(0, // baseIdx + static_cast(kPartCeil / BLOCK_SIZE), // repeat + kSrcStride, // srcStride + 0, // sid + kDstStride - 1, // dstStride + false, // transpose + 0)); // addrCalMode + } + }; +}; + +// Partial specialization for transpose, not vector +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < mTileCeil / BLOCK_SIZE; i++) { + AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], + l1Tensor[i * mSrcStride * FRACTAL_SIZE], + AscendC::LoadData2dParams(0, + static_cast(kPartCeil / BLOCK_NUM_PER_FRACTAL), + kSrcStride, + 0, + kDstStride - 1, + true, + 0)); + } + }; +}; + +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + // 16 * 32 + static constexpr uint32_t ROW_BLOCK_SIZE = 16; + static constexpr uint32_t COL_BLOCK_SIZE = 32 / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < mTileCeil / ROW_BLOCK_SIZE; i++) { + 
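+            // each iteration loads one 16-row block: repeat over kPartCeil / COL_BLOCK_SIZE fractals in K,
+            // stepping the L1 source by mTileCeil / ROW_BLOCK_SIZE fractals between repeats (see LoadData2dParams below)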
AscendC::LoadData(l0Tensor[i * ROW_BLOCK_SIZE * kPartCeil], + l1Tensor[i * FRACTAL_SIZE], + AscendC::LoadData2dParams(0, + static_cast(kPartCeil / COL_BLOCK_SIZE), + mTileCeil / ROW_BLOCK_SIZE, + 0, + 0, + false, + 0)); + } + }; +}; + +template <> +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32 + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 512 + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; // 16 + static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2; + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + for (uint64_t i = 0; i < mTileCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) { + AscendC::LoadDataWithTranspose( + l0Tensor[i * mDstStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // dstLocalTensor + l1Tensor[i * mSrcStride * FRACTAL_SIZE], // srcLocalTensor + AscendC::LoadData2dTransposeParams(0, // baseIdx + static_cast(CeilDiv(kPartCeil)), // repeat + kSrcStride, // srcStride + 0, // dstGap + mDstStride - 1)); // dstFracGap + } + } +}; + +///////////////////////////////////////////////////// +// l1_to_l0_b +///////////////////////////////////////////////////// + +// Partial specialization for vector +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + AscendC::LoadData( + l0Tensor, l1Tensor, AscendC::LoadData2dParams(0, kPartCeil, kSrcStride, 0, kDstStride, IsTransPose, 0)); + }; +}; + +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + using DataType = int8_t; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < nTileCeil / BLOCK_SIZE; i++) { + AscendC::LoadDataWithTranspose(l0Tensor[i * kPartCeil * BLOCK_SIZE], + l1Tensor[i * BLOCK_SIZE * BLOCK_SIZE], + AscendC::LoadData2dTransposeParams(0, // startIndexIn + kPartCeil / BLOCK_SIZE, // repeatTimesIn + nTileCeil / BLOCK_SIZE, // srcStrideIn + 1, // dstGapIn + 0, // dstfracGapIn + 0) // addrModeIn + ); + } + }; +}; + +// Partial specialization for no transpose, not vector +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < kPartCeil / 
BLOCK_NUM_PER_FRACTAL; i++) { + AscendC::LoadData(l0Tensor[i * kDstStride * FRACTAL_SIZE], + l1Tensor[i * kSrcStride * FRACTAL_SIZE], + AscendC::LoadData2dParams(0, // baseIdx + static_cast(nTileCeil / BLOCK_SIZE), // repeat + nSrcStride, // srcStride + 0, // sid + nDstStride - 1, // dstStride + true, // transpose + 0)); // addrCalMode + } + }; +}; + +// Partial specialization for transpose, not vector +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + AscendC::LoadData( + l0Tensor, + l1Tensor, + AscendC::LoadData2dParams(0, // baseIdx + static_cast(kPartCeil * nTileCeil / FRACTAL_SIZE), // repeat + 1, // srcStride + 0, // sid + 0, // dstStride + false, // transpose + 0)); // addr_cal_mode_t + }; +}; + +template <> +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32 + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 16 + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2; + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + for (uint64_t i = 0; i < kPartCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) { + AscendC::LoadDataWithTranspose( + l0Tensor[i * kDstStride * FRACTAL_SIZE], // dstLocalTensor + l1Tensor[i * kSrcStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // srcLocalTensor + AscendC::LoadData2dTransposeParams(0, // baseIdx + static_cast(CeilDiv(nTileCeil)), // repeat + nSrcStride / NUM_FRACTAL_PER_ITER, // srcStride + 1, // dstGap + 0)); // dstFracGap + } + }; +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_ub_iterator.inc b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_ub_iterator.inc new file mode 100644 index 0000000000..e7e075ac7c --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/iterators/l1_to_ub_iterator.inc @@ -0,0 +1,44 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_ub +///////////////////////////////////////////////////// +template +struct l1_to_ub { + __aicore__ l1_to_ub(AscendC::LocalTensor ubTensor, + AscendC::LocalTensor l1Tensor, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcStride, + uint16_t dstStride) + { + AscendC::DataCopy(ubTensor, l1Tensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +///////////////////////////////////////////////////// +// ub_to_l1 +///////////////////////////////////////////////////// +template +struct ub_to_l1 { + __aicore__ ub_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::LocalTensor ubTensor, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcStride, + uint16_t dstStride) + { + AscendC::DataCopy(l1Tensor, ubTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; diff --git a/csrc/mla_preprocess/op_kernel/kernel/kernel_utils.h b/csrc/mla_preprocess/op_kernel/kernel/kernel_utils.h new file mode 100644 index 0000000000..25c97cf003 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/kernel_utils.h @@ -0,0 +1,395 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H +#define ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H +#include "kernel_operator.h" + +using AscendC::HardEvent; + +__aicore__ inline uint32_t CeilDiv(uint32_t x, uint32_t y) +{ + return y == 0 ? 0 : ((x + y - 1) / y); +} + +__aicore__ inline uint32_t RoundUp(uint32_t x, uint32_t y = 16) +{ + return (x + y - 1) / y * y; +} + +__aicore__ inline uint32_t Min(uint32_t x, uint32_t y) +{ + return x < y ? x : y; +} + +__aicore__ inline uint32_t Max(uint32_t x, uint32_t y) +{ + return x > y ? 
x : y;
+}
+
+template
+__aicore__ inline void CopyIn(const AscendC::GlobalTensor &gm, Q &queue, uint64_t offset, uint32_t count)
+{
+    AscendC::LocalTensor local = queue.template AllocTensor();
+    DataCopy(local, gm[offset], count);
+    queue.EnQue(local);
+}
+
+template
+__aicore__ inline void CopyOut(const AscendC::GlobalTensor &gm, Q &queue, uint64_t offset, uint32_t count)
+{
+    AscendC::LocalTensor local = queue.template DeQue();
+    DataCopy(gm[offset], local, count);
+    queue.FreeTensor(local);
+}
+
+template
+__aicore__ inline void CastFrom16To32(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in,
+                                      uint32_t count)
+{
+    Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
+    AscendC::PipeBarrier();
+}
+
+template
+__aicore__ inline void CastFrom32To16(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in,
+                                      uint32_t count)
+{
+    if constexpr (AscendC::IsSameType::value) {
+        Cast(out, in, AscendC::RoundMode::CAST_NONE,
+             count); // on 310p the fp32->half cast only supports CAST_NONE; use it here so 310p matches 910b
+    } else { // bf16
+        Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
+    }
+    AscendC::PipeBarrier();
+}
+
+__aicore__ inline void CastFromF16ToI8(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in,
+                                       half quantMin, uint32_t count)
+{
+    Maxs(in, in, quantMin, count);
+    AscendC::PipeBarrier();
+    Mins(in, in, (half)127, count); // 127: limit
+    AscendC::PipeBarrier();
+#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
+    Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
+#else
+    Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
+#endif
+    AscendC::PipeBarrier();
+}
+
+template
+__aicore__ inline void CopyInAndCastF32(const AscendC::LocalTensor &out, const AscendC::GlobalTensor &gm,
+                                        Q &queue, uint64_t offset, uint32_t count)
+{
+    CopyIn(gm, queue, offset, count);
+    AscendC::LocalTensor local = queue.template DeQue();
+    Cast(out, local, AscendC::RoundMode::CAST_NONE, count);
+    queue.FreeTensor(local);
+    AscendC::PipeBarrier();
+}
+
+template
+__aicore__ inline void Cast16AndCopyOut(const AscendC::LocalTensor &in, const AscendC::GlobalTensor &gm,
+                                        Q &queue, uint64_t offset, uint32_t count)
+{
+    AscendC::LocalTensor local = queue.template AllocTensor();
+    CastFrom32To16(local, in, count);
+    queue.EnQue(local);
+    CopyOut(gm, queue, offset, count);
+    AscendC::PipeBarrier();
+}
+
+template
+__aicore__ inline T ComputeSum(const AscendC::LocalTensor &in, const AscendC::LocalTensor &tmp,
+                               const AscendC::LocalTensor &workLocal, uint32_t count)
+{
+#if __CCE_AICORE__ == 100
+    float sum = 0;
+    int64_t elementNumPerRep = AscendC::ONE_REPEAT_BYTE_SIZE / sizeof(T);
+    AscendC::LocalTensor src = in;
+    while (count > elementNumPerRep) {
+        int64_t repeatTimes = count / elementNumPerRep;
+        int64_t tailCount = count % elementNumPerRep;
+        int64_t bodyCount = repeatTimes * elementNumPerRep;
+        if (repeatTimes > 0) {
+            AscendC::AscendCUtils::SetMask(elementNumPerRep);
+            vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)src.GetPhyAddr(), repeatTimes, 1, 1, 8);
+            AscendC::SetFlag(EVENT_ID0); // PipeBarrier(PIPE_V)?
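+            // the Set/Wait event pair ensures the vcadd result is visible before tmp is read on the scalar path below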
+ AscendC::WaitFlag(EVENT_ID0); + } + + if (tailCount != 0) { + AscendC::AscendCUtils::SetMask(tailCount); + vcadd((__ubuf__ T *)tmp[bodyCount].GetPhyAddr(), (__ubuf__ T *)src[bodyCount].GetPhyAddr(), 1, 1, 1, 8); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + sum += tmp.GetValue(bodyCount); + } + + count = repeatTimes; + src = tmp; + } + + if (count > 1) { + AscendC::AscendCUtils::SetMask(count); + vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)tmp.GetPhyAddr(), 1, 1, 1, 8); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + sum += tmp.GetValue(0); + return sum; +#else + ReduceSum(tmp, in, workLocal, count); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + return tmp.GetValue(0); +#endif +} + +__aicore__ inline float ComputeSliceSquareSum(const AscendC::LocalTensor &in, + const AscendC::LocalTensor &tmp, + const AscendC::LocalTensor &workLocal, uint32_t count) +{ + Mul(tmp, in, in, count); + AscendC::PipeBarrier(); + return ComputeSum(tmp, tmp, workLocal, count); +} +template +__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in, + float rms, const AscendC::LocalTensor &gamma, uint32_t count, + uint32_t precisionMode, uint32_t gemmaMode, + const AscendC::LocalTensor &tmp) +{ + float value = 1.0; + Duplicate(tmp, rms, count); + AscendC::PipeBarrier(); + Div(tmp, in, tmp, count); + AscendC::PipeBarrier(); + + if (precisionMode == 0) { + CastFrom16To32(in, gamma, count); + AscendC::PipeBarrier(); + if (gemmaMode == 1) { + Adds(in, in, value, count); + AscendC::PipeBarrier(); + } + Mul(in, in, tmp, count); + AscendC::PipeBarrier(); + CastFrom32To16(out, in, count); + return; + } + if constexpr (std::is_same::value) { + CastFrom32To16(out, tmp, count); + Mul(out, out, gamma, count); + AscendC::PipeBarrier(); + } +} + +template +__aicore__ inline void CastGAndIsGemmaMode(const AscendC::LocalTensor &out, const AscendC::LocalTensor &gamma, + uint32_t count) +{ + Cast(out, gamma, AscendC::RoundMode::CAST_NONE, count); + AscendC::PipeBarrier(); + float value = 1.0; + if constexpr (gemmaMode == 1) { + Adds(out, out, value, count); + AscendC::PipeBarrier(); + } +} + +template +__aicore__ inline void ComputeRmsNormFast(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in, + float rms, const AscendC::LocalTensor &gamma, uint32_t count, + const AscendC::LocalTensor &tmp, + const AscendC::LocalTensor &fp32_g) +{ + float value = 1.0; + Duplicate(tmp, rms, count); + AscendC::PipeBarrier(); + Div(tmp, in, tmp, count); + AscendC::PipeBarrier(); + if constexpr (precisionMode == 0) { + Mul(in, fp32_g, tmp, count); + AscendC::PipeBarrier(); + CastFrom32To16(out, in, count); + return; + } + if constexpr (std::is_same::value) { + CastFrom32To16(out, tmp, count); + Mul(out, out, gamma, count); + AscendC::PipeBarrier(); + } +} + +template +__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in, + float rms, const AscendC::LocalTensor &gamma, + const AscendC::LocalTensor &beta, const AscendC::LocalTensor &tmp, + uint32_t count) +{ + Duplicate(tmp, rms, count); + AscendC::PipeBarrier(); + Div(out, in, tmp, count); + AscendC::PipeBarrier(); + CastFrom16To32(tmp, gamma, count); + Mul(out, out, tmp, count); + AscendC::PipeBarrier(); + if constexpr (WITH_BETA) { + CastFrom16To32(tmp, beta, count); + Add(out, out, tmp, count); + AscendC::PipeBarrier(); + } +} + +template +__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor &out, const 
AscendC::LocalTensor &in, + float reciprocal_of_rms, const AscendC::LocalTensor &gamma, + const AscendC::LocalTensor &tmp, const AscendC::LocalTensor &res_out, + uint32_t count) +{ + Duplicate(tmp, reciprocal_of_rms, count); + AscendC::PipeBarrier(); + Mul(out, in, tmp, count); + AscendC::PipeBarrier(); + CastFrom16To32(tmp, gamma, count); + Mul(out, out, tmp, count); + AscendC::PipeBarrier(); + CastFrom32To16(res_out, out, count); +} + +template +__aicore__ inline void ComputeResidualAdd(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in, + const AscendC::LocalTensor &resIn, uint32_t count) +{ + Add(out, in, resIn, count); + AscendC::PipeBarrier(); +} + +template +__aicore__ inline void ComputeMean(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in, T aveNum, + uint32_t count) +{ + Duplicate(out, aveNum, count); + AscendC::PipeBarrier(); + Mul(out, in, out, count); + AscendC::PipeBarrier(); + T sum = ComputeSum(out, out, out, count); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + Duplicate(out, sum, count); + AscendC::PipeBarrier(); +} + +template +__aicore__ inline void ComputeLayerNorm(const AscendC::LocalTensor &out, const AscendC::LocalTensor &in, + const AscendC::LocalTensor &mean, float eps, float aveNum, + const AscendC::LocalTensor &gamma, const AscendC::LocalTensor &beta, + uint32_t count) +{ + Sub(in, in, mean, count); + AscendC::PipeBarrier(); + Mul(out, in, in, count); + AscendC::PipeBarrier(); + Muls(out, out, aveNum, count); + AscendC::PipeBarrier(); + ReduceSum(out, out, out, count); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + float var = out.GetValue(0); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + Duplicate(out, var, count); + AscendC::PipeBarrier(); + Adds(out, out, eps, count); + AscendC::PipeBarrier(); + Sqrt(out, out, count); + AscendC::PipeBarrier(); + + Div(out, in, out, count); + AscendC::PipeBarrier(); + + Cast(in, gamma, AscendC::RoundMode::CAST_NONE, count); + AscendC::PipeBarrier(); + Mul(out, out, in, count); + AscendC::PipeBarrier(); + Cast(in, beta, AscendC::RoundMode::CAST_NONE, count); + AscendC::PipeBarrier(); + Add(out, out, in, count); + AscendC::PipeBarrier(); +} + +__aicore__ inline void ComputeFp16ToI8Quant(const AscendC::LocalTensor &out, + const AscendC::LocalTensor &in, const AscendC::LocalTensor &tmp, + half scale, half offset, half quantMin, uint32_t count) +{ + Muls(tmp, in, scale, count); + AscendC::PipeBarrier(); + Adds(tmp, tmp, offset, count); + AscendC::PipeBarrier(); + CastFromF16ToI8(out, tmp, quantMin, count); +} + +__aicore__ inline void ComputeFp32ToI8Quant(const AscendC::LocalTensor &out, + const AscendC::LocalTensor &in, + const AscendC::LocalTensor &tmp, half scale, half offset, + half quantMin, uint32_t count) +{ + CastFrom32To16(tmp, in, count); + AscendC::PipeBarrier(); + ComputeFp16ToI8Quant(out, tmp, tmp, scale, offset, quantMin, count); +} + +__aicore__ inline void ComputeHighPrecisionFp32ToI8Quant(const AscendC::LocalTensor &out, + const AscendC::LocalTensor &in, + const AscendC::LocalTensor &tmp, float scale, + float offset, half quantMin, uint32_t count) +{ + Muls(in, in, scale, count); + AscendC::PipeBarrier(); + Adds(in, in, offset, count); + AscendC::PipeBarrier(); + CastFrom32To16(tmp, in, count); + CastFromF16ToI8(out, tmp, quantMin, count); +} + +__aicore__ inline void CopyGmTilingToUb(__ubuf__ uint8_t *&tilingInUb, const __gm__ uint8_t *tilingInGm, + size_t tilingSize, AscendC::TPipe *pipe) +{ + uint32_t roundTilingSize = 
RoundUp(tilingSize, 32); + AscendC::TBuf tilingBuf; + AscendC::GlobalTensor tilingGm; + + tilingGm.SetGlobalBuffer((__gm__ uint8_t *)tilingInGm); + pipe->InitBuffer(tilingBuf, roundTilingSize); + + AscendC::LocalTensor tilingUb = tilingBuf.Get(); + AscendC::DataCopy(tilingUb, tilingGm, roundTilingSize); + + tilingInUb = (__ubuf__ uint8_t *)tilingUb.GetPhyAddr(); +} + +template +__aicore__ inline uint32_t GetReduceSumWorkLocalSize(uint32_t sliceSize) +{ + uint32_t elementsPerBlock = 32 / sizeof(T); + uint32_t elementsPerRepeat = 256 / sizeof(T); + + uint32_t firstMaxRepeat = sliceSize < elementsPerRepeat ? 1u : (sliceSize / elementsPerRepeat); + uint32_t iter1OutputCount = firstMaxRepeat; + uint32_t iter1AlignEnd = RoundUp(iter1OutputCount, elementsPerBlock); + return iter1AlignEnd; +} + +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/layout.h b/csrc/mla_preprocess/op_kernel/kernel/layout.h new file mode 100644 index 0000000000..b7b7139373 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/layout.h @@ -0,0 +1,18 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef INCLUDE_LAYOUT_H +#define INCLUDE_LAYOUT_H + +enum class DataFormat { ND = 0, NZ, ZN, ZZ, NN, VECTOR }; + +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/mem.h b/csrc/mla_preprocess/op_kernel/kernel/mem.h new file mode 100644 index 0000000000..116b14ed3b --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/mem.h @@ -0,0 +1,82 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef INCLUDE_MEM_H +#define INCLUDE_MEM_H + +#include "hardware.h" +#include "kernel_event.h" +#include "kernel_tensor.h" + +enum class BufferType { ASCEND_UB, ASCEND_CB, ASCEND_L0A, ASCEND_L0B, ASCEND_L0C, ASCEND_MAX }; + +template +__aicore__ constexpr AscendC::TPosition GetPosition() +{ + if constexpr (BufferType_ == BufferType::ASCEND_UB) { + return AscendC::TPosition::VECIN; + } else if constexpr (BufferType_ == BufferType::ASCEND_CB) { + return AscendC::TPosition::A1; + } else if constexpr (BufferType_ == BufferType::ASCEND_L0A) { + return AscendC::TPosition::A2; + } else if constexpr (BufferType_ == BufferType::ASCEND_L0B) { + return AscendC::TPosition::B2; + } else if constexpr (BufferType_ == BufferType::ASCEND_L0C) { + return AscendC::TPosition::CO1; + } + return AscendC::TPosition::GM; +} + +template +struct AsdopsBuffer { +public: + __aicore__ AsdopsBuffer() + { + constexpr uint32_t bufferSize[(uint32_t)BufferType::ASCEND_MAX] = { + HardwareInfo::ubSize, HardwareInfo::l1Size, HardwareInfo::l0ASize, + HardwareInfo::l0BSize, HardwareInfo::l0CSize}; +#ifdef __DAV_C220_VEC__ + tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]); + tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast(AscendC::TPosition::VECIN); +#elif defined(__DAV_C220_CUBE__) + tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]); + tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast(AscendC::TPosition::A1); + tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]); + tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast(AscendC::TPosition::A2); + tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]); + tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast(AscendC::TPosition::B2); + tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]); + tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast(AscendC::TPosition::CO1); +#else + tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]); + tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast(AscendC::TPosition::VECIN); + tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]); + tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast(AscendC::TPosition::A1); + tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]); + tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast(AscendC::TPosition::A2); + tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]); + tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast(AscendC::TPosition::B2); + tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]); + tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast(AscendC::TPosition::CO1); +#endif + }; + + template + __aicore__ AscendC::LocalTensor GetBuffer(const uint32_t offset) const + { + return tensor[(uint32_t)BufferType_][offset].template ReinterpretCast(); + } + +public: + AscendC::LocalTensor tensor[(uint32_t)BufferType::ASCEND_MAX]; +}; + +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/mma.h 
b/csrc/mla_preprocess/op_kernel/kernel/mma.h new file mode 100644 index 0000000000..ecfa6f873d --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/mma.h @@ -0,0 +1,67 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef INCLUDE_MMA_H +#define INCLUDE_MMA_H + +#include "hardware.h" +#include "kernel_tensor.h" + +template +struct mmad { + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint32_t mTileActual, uint32_t nTileActual, + uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {}; + + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint64_t biasBt, uint32_t mTileActual, + uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {}; +}; + +// Partial specialization for V220, int8_t, not_vector_A, not TransposeA +template +struct mmad { + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint32_t mTileActual, uint32_t nTileActual, + uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) + { + AscendC::Mmad(l0cTensor, // C + l0aTensor, // A + l0bTensor, // B + AscendC::MmadParams(mTileActual, // m + nTileActual, // n + kPartActual, // k + unitFlag, // unitFlag + false, // cmatrixSource + initC)); // cmatrixInitVal + }; + + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint64_t biasBt, uint32_t mTileActual, + uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) + { + AscendC::LocalTensor biasTensor; + biasTensor.InitBuffer(biasBt, nTileActual); + biasTensor.address_.logicPos = static_cast(AscendC::TPosition::C2); + AscendC::Mmad(l0cTensor, // C + l0aTensor, // A + l0bTensor, // B + biasTensor, // bt + AscendC::MmadParams(mTileActual, // m + nTileActual, // n + kPartActual, // k + unitFlag, // unitFlag + true, // cmatrixSource + false)); // cmatrixInitVal + }; +}; + +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/set_fpc.h b/csrc/mla_preprocess/op_kernel/kernel/set_fpc.h new file mode 100644 index 0000000000..b32aa6b8a1 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/set_fpc.h @@ -0,0 +1,38 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef INCLUDE_SET_FPC_H +#define INCLUDE_SET_FPC_H + +#include "hardware.h" +#include "kernel_tensor.h" + +///////////////////////////////////////////////////// +// SetQuantPreAddr +///////////////////////////////////////////////////// +template +struct SetQuantPreAddr { + __aicore__ SetQuantPreAddr(AscendC::LocalTensor quantPreTensor) {}; +}; + +template +struct SetQuantPreAddr { + static constexpr uint32_t QUANT_PRE_ADDR_MASK = 0xffff; + static constexpr uint32_t USELESS_BIT_NUM = 7; + static constexpr uint32_t QUANT_PRE_BIT_POS_IN_FPC = 8; + + __aicore__ SetQuantPreAddr(AscendC::LocalTensor quantPreTensor) + { + uint64_t quantPreAddr = (uint64_t)(__fbuf__ uint64_t *)quantPreTensor.GetPhyAddr(); + AscendC::SetFixPipeConfigImpl(quantPreTensor); + }; +}; +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/simd.h b/csrc/mla_preprocess/op_kernel/kernel/simd.h new file mode 100644 index 0000000000..a90f83bb1a --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/simd.h @@ -0,0 +1,274 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef INCLUDE_SIMD_H +#define INCLUDE_SIMD_H + +#include "hardware.h" +#include "kernel_operator.h" + +///////////////////////////////////////////////////// +// vcgadd +///////////////////////////////////////////////////// +template +__aicore__ inline void cgadd_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, const int32_t repeat, + const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride) +{ + AscendC::BlockReduceSum(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride); +} + +///////////////////////////////////////////////////// +// vadd +///////////////////////////////////////////////////// +template +__aicore__ inline void add_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) +{ + AscendC::Add(dst, src0, src1, (uint64_t)0, repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vadds +///////////////////////////////////////////////////// +template +__aicore__ inline void adds_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, DType scalarValue, + uint8_t repeat, uint8_t dstBlockStride, uint8_t srcBlockStride, uint8_t dstRepeatStride, + uint8_t srcRepeatStride) +{ + AscendC::Adds( + dst, src, scalarValue, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vcadd +///////////////////////////////////////////////////// +template +__aicore__ inline void cadd_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride) +{ + AscendC::RepeatReduceSum(dst, src, repeat, 0, 0, srcBlockStride, dstRepeatStride, srcRepeatStride); +} +///////////////////////////////////////////////////// +// vbrcb +///////////////////////////////////////////////////// +template +__aicore__ inline void brcb_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint16_t dstBlockStride, + uint16_t dstRepeatStride, uint8_t repeat) +{ + AscendC::Brcb(dst, src, repeat, AscendC::BrcbRepeatParams(dstBlockStride, dstRepeatStride)); +} + +///////////////////////////////////////////////////// +// vcmax +///////////////////////////////////////////////////// +template +__aicore__ inline void cmax_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride) +{ +#if defined(__DAV_C220_VEC__) + AscendC::WholeReduceMax(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride, + srcRepeatStride, OrderType); +#else + AscendC::WholeReduceMax(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride, + srcRepeatStride); +#endif +} + +///////////////////////////////////////////////////// +// vconv +///////////////////////////////////////////////////// +template +__aicore__ inline void conv_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) +{ + if constexpr (std::is_same::value && std::is_same::value) { + AscendC::Cast( + dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat, + 
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); + } else { + AscendC::Cast( + dst, src, AscendC::RoundMode::CAST_NONE, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); + } +} + +///////////////////////////////////////////////////// +// vconv_f322bf16r +///////////////////////////////////////////////////// +template +__aicore__ inline void convr_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) +{ + AscendC::Cast( + dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vdiv +///////////////////////////////////////////////////// +template +__aicore__ inline void div_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) +{ + AscendC::Div(dst, src0, src1, (uint64_t)0, repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vexp +///////////////////////////////////////////////////// +template +__aicore__ inline void exp_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) +{ + AscendC::Exp( + dst, src, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vmax +///////////////////////////////////////////////////// +template +__aicore__ inline void max_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) +{ + AscendC::Max(dst, src0, src1, (uint64_t)0, repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vmul +///////////////////////////////////////////////////// +template +__aicore__ inline void mul_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) +{ + AscendC::Mul(dst, src0, src1, (uint64_t)0, repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vmuls +///////////////////////////////////////////////////// +template +__aicore__ inline void muls_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, DType src1, + uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride, + uint16_t dstRepeatStride, uint16_t srcRepeatStride) +{ + AscendC::Muls( + dst, src0, src1, (uint64_t)0, repeat, + 
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vsub +///////////////////////////////////////////////////// +template +__aicore__ inline void sub_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) +{ + AscendC::Sub(dst, src0, src1, (uint64_t)0, repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vmaxs +///////////////////////////////////////////////////// +template +__aicore__ inline void maxs_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, DType src1, + uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride, + uint16_t dstRepeatStride, uint16_t srcRepeatStride) +{ + AscendC::Maxs( + dst, src0, src1, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vmins +///////////////////////////////////////////////////// +template +__aicore__ inline void mins_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, DType src1, + uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride, + uint16_t dstRepeatStride, uint16_t srcRepeatStride) +{ + AscendC::Mins( + dst, src0, src1, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vsqrt +///////////////////////////////////////////////////// +template +__aicore__ inline void sqrt_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) +{ + AscendC::Sqrt( + dst, src, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vln +///////////////////////////////////////////////////// +template +__aicore__ inline void ln_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) +{ + AscendC::Ln( + dst, src, (uint64_t)0, repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vtranspose +///////////////////////////////////////////////////// +template +__aicore__ inline void tranpose_v(AscendC::LocalTensor dst, AscendC::LocalTensor src) +{ + AscendC::Transpose(dst, src); +} + +///////////////////////////////////////////////////// +// vcgmax +///////////////////////////////////////////////////// +template +__aicore__ inline void cgmax_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, const int32_t repeat, + const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride) +{ + AscendC::BlockReduceMax(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride); +} +#endif diff --git a/csrc/mla_preprocess/op_kernel/kernel/utils.h b/csrc/mla_preprocess/op_kernel/kernel/utils.h new file mode 100644 index 0000000000..932eae2906 --- 
/dev/null +++ b/csrc/mla_preprocess/op_kernel/kernel/utils.h @@ -0,0 +1,69 @@ +/* Adapted from + * https://gitee.com/ascend/ascend-transformer-boost.git + * + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef INCLUDE_UTILS_H +#define INCLUDE_UTILS_H + +template +__aicore__ inline void CreateCaMatrix(const AscendC::LocalTensor &dst, const uint16_t repeats, + const uint16_t blockNum, const uint16_t dstGap, const IN_DTYPE initValue) +{ + AscendC::InitConstValue(dst, + AscendC::InitConstValueParams(repeats, blockNum, dstGap, initValue)); +} +__aicore__ inline void SetFftsBaseAddr(uint64_t config) +{ + AscendC::SetSyncBaseAddr(config); +} +template +__aicore__ inline void SetPadding(IN_DTYPE padValue) +{ + AscendC::SetLoadDataPaddingValue(padValue); +} +__aicore__ inline void SetAtomicnone() +{ + AscendC::SetAtomicNone(); +} +__aicore__ inline void SetMasknorm() +{ +#if __CCE_AICORE__ == 100 + return; +#endif + AscendC::SetMaskNorm(); +} +__aicore__ inline void SetNdpara(uint16_t ndNum, uint16_t srcNdStride, uint16_t dstNdStride) +{ + AscendC::SetFixpipeNz2ndFlag(ndNum, srcNdStride, dstNdStride); +} +template +__aicore__ inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow) +{ + AscendC::SetVectorMask(maskHigh, maskLow); +} +__aicore__ inline int64_t GetSubBlockidx() +{ + return AscendC::GetSubBlockIdx(); +} +__aicore__ inline void WaitFlagDev(uint16_t flagId) +{ + AscendC::WaitEvent(flagId); +} +template +__aicore__ inline void FftsCrossCoreSync(uint16_t flagId) +{ + AscendC::CrossCoreSetFlag(flagId); +} +template +__aicore__ inline void SetFpc(const AscendC::LocalTensor &preTensor, bool isUnitFlag = false) +{ + AscendC::SetFixPipeConfig(preTensor, isUnitFlag); +} +#endif diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess.h b/csrc/mla_preprocess/op_kernel/mla_preprocess.h new file mode 100644 index 0000000000..35254112b5 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess.h @@ -0,0 +1,114 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+// + +#ifndef __MLA_PREPROCESS_H__ +#define __MLA_PREPROCESS_H__ + +// sync +constexpr int32_t QUANT1 = 1; +constexpr int32_t MM1 = 2; +constexpr int32_t MM1QUANT = 3; +constexpr int32_t RMSNORMQUANT2 = 4; +constexpr int32_t MM2 = 5; +constexpr int32_t MM2QUANT = 6; +constexpr int32_t BMM3 = 7; +constexpr int32_t BMM3SPLIT = 8; +constexpr int32_t MM2OUT = 9; +constexpr int32_t EINSUMOUT = 11; +constexpr int32_t EINSUMQUANT = 12; + +// ropeConcat +constexpr uint32_t ELE_NUM_FP16 = 16; // nums of fp16 elements in one block +constexpr uint32_t ELE_NUM_FP32 = 8; // nums of fp32 elements in one block +constexpr uint8_t DEFAULT_REPEAT_STRIDE = 8; // stride, 8 * 32 = 256 + +// rmsNormQuant +constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float); +constexpr float ZERO = 0; +constexpr uint32_t BUF_FACTOR = 3; // 1(g) + 1(sqx) + 1(sum) = 3 +constexpr uint32_t OFFSET_GAMMA = 0; // the offset of gamma is 0 +constexpr uint32_t OFFSET_SQX = 1; // the offset of sqx is 1 +constexpr uint32_t OFFSET_SUM = 2; // the offset of sum is 2 +constexpr uint32_t OFFSET_WORKSPACE = 3; // the offset of workspace is 3 +constexpr uint32_t REPEAT_TIME_256 = 256; // 128 default stride +constexpr uint32_t REPEAT_TIME_128 = 128; // 128 default stride +constexpr uint32_t REPEAT_TIME_64 = 64; // 64 default stride + +constexpr uint8_t CACHE_MODE_KVCACHE = 0; // single input single output +constexpr uint8_t CACHE_MODE_KROPE_CTKV = 1; // double in and double out +constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format/quant int8 +constexpr uint8_t CACHE_MODE_NZCACHE = 3; + +// pp matmul +constexpr uint32_t HIDDTEN_STATE = 7168; +constexpr uint32_t FLOAT_BLOCK_SIZE = 64; +constexpr uint32_t HALF_BLOCK_SIZE = 64; +constexpr uint32_t HALF_VECTOR_SIZE = 64; +constexpr uint32_t MM1_OUT_SIZE = 2112; +constexpr uint32_t SPLIT_SIZE_ONE = 576; +constexpr uint32_t SPLIT_SIZE_TWO = 1536; +constexpr uint32_t SPLIT_RMSNRORM_SIZE_ONE = 512; +constexpr uint32_t SPLIT_RMSNRORM_SIZE_TWO = 64; +constexpr uint32_t ROPE_SPLIT_SIZE_ONE = 64; +constexpr uint32_t ROPE_SPLIT_SIZE_TWO = 128; + +constexpr uint32_t MMSIZE1 = 128 * 192; // 24576 +constexpr uint32_t MMSIZE2 = 64 * 128; // 8192 + +constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; // 32 KB +constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; // 256 KB +constexpr uint64_t BLOCK_SIZE_16 = 16; +constexpr uint64_t BLOCK_SIZE_32 = 32; +constexpr uint64_t CUBE_MATRIX_SIZE_512 = 16 * 32; // 16 * 23 +constexpr uint64_t FB_BUFF_SIZE = 1024 * 7; +constexpr uint64_t SCALE_L1_LEN = 4096; +constexpr uint64_t BIAS_L1_LEN = 2048; + +constexpr uint64_t CONST_0 = 0; +constexpr uint64_t CONST_4 = 4; +constexpr uint64_t CONST_8 = 8; +constexpr uint64_t CONST_32 = 32; +constexpr uint64_t CONST_64 = 64; +constexpr uint64_t CONST_128 = 128; + +// ropeConcat +constexpr uint32_t ROPE_CONCAT_NUM_BUFFER = 2; + +// rmsNormQuant +constexpr uint32_t OFFSET_ABS = 3; // the offset of abs is 3 +constexpr uint32_t OFFSET_WORKSPACE_BF16 = 4; // the offset of workspace is 4 + +// sync bf16 +constexpr int32_t AIC_MM1_START = 2; +constexpr int32_t AIC_MM3_START = 3; +constexpr int32_t AIC_MM2_START = 6; +constexpr int32_t MMAIC = 7; +constexpr int32_t MMAIV = 8; + +constexpr uint32_t MAX_HW_SYNC_COUNTER = 5; +constexpr uint32_t SYNC_MODE = 2; + +// TilingKey +constexpr uint32_t KEY_FP16_CACHEMODE_0_QUANTMODE_0 = 0; +constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1; +constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256; +constexpr uint32_t 
KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257; +constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259; + +enum class QuantMode : int32_t { + PER_TENSOR_ASYMM_QUANT = 0, + PER_TOKEN_SYMM_QUANT, + PER_TOKEN_ASYMM_QUANT, + NO_QUANT, +}; + +#endif // __MLA_PREPROCESS_H__ diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp new file mode 100644 index 0000000000..36657cabdd --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp @@ -0,0 +1,297 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// + +#include "kernel_operator.h" +#include "../../kernels/types.h" + +#include "mla_preprocess_mix_fp16.hpp" +#include "mla_preprocess_mix_bf16.hpp" + +#include "../op_host/tiling/mla_preprocess_tiling.h" + +extern "C" __global__ __aicore__ void mla_preprocess( + GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv, + GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3, + GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq, + GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q, + GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling) +{ + PRELOAD(2); + + SetAtomicnone(); + SetMasknorm(); +#ifdef __DAV_C220_CUBE__ + SetPadding((uint64_t)0); + SetNdpara(1, 0, 0); +#endif + + MlaTilingData mlaTilingData; + __gm__ MlaTilingData *tilingData = reinterpret_cast<__gm__ MlaTilingData *>(tiling); + + mlaTilingData.tilingKey = tilingData->tilingKey; + mlaTilingData.n = tilingData->n; + + mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch; + mlaTilingData.mm1.m = tilingData->mm1.m; + mlaTilingData.mm1.k = tilingData->mm1.k; + mlaTilingData.mm1.n = tilingData->mm1.n; + mlaTilingData.mm1.m0 = tilingData->mm1.m0; + mlaTilingData.mm1.k0 = tilingData->mm1.k0; + mlaTilingData.mm1.n0 = tilingData->mm1.n0; + mlaTilingData.mm1.mLoop = tilingData->mm1.mLoop; + mlaTilingData.mm1.kLoop = tilingData->mm1.kLoop; + mlaTilingData.mm1.nLoop = tilingData->mm1.nLoop; + mlaTilingData.mm1.coreLoop = tilingData->mm1.coreLoop; + mlaTilingData.mm1.swizzleCount = tilingData->mm1.swizzleCount; + mlaTilingData.mm1.enShuffleK = tilingData->mm1.enShuffleK; + mlaTilingData.mm1.blockDim = tilingData->mm1.blockDim; + mlaTilingData.mm1.enLoadAllAmat = tilingData->mm1.enLoadAllAmat; + mlaTilingData.mm1.b0matPingPongBufferLen = tilingData->mm1.b0matPingPongBufferLen; + + mlaTilingData.mm2.numBatch = tilingData->mm2.numBatch; + mlaTilingData.mm2.m = tilingData->mm2.m; + mlaTilingData.mm2.k = tilingData->mm2.k; + mlaTilingData.mm2.n = tilingData->mm2.n; + mlaTilingData.mm2.m0 = tilingData->mm2.m0; + mlaTilingData.mm2.k0 = tilingData->mm2.k0; + mlaTilingData.mm2.n0 = 
tilingData->mm2.n0; + mlaTilingData.mm2.mLoop = tilingData->mm2.mLoop; + mlaTilingData.mm2.kLoop = tilingData->mm2.kLoop; + mlaTilingData.mm2.nLoop = tilingData->mm2.nLoop; + mlaTilingData.mm2.coreLoop = tilingData->mm2.coreLoop; + mlaTilingData.mm2.swizzleCount = tilingData->mm2.swizzleCount; + mlaTilingData.mm2.enShuffleK = tilingData->mm2.enShuffleK; + mlaTilingData.mm2.blockDim = tilingData->mm2.blockDim; + mlaTilingData.mm2.enLoadAllAmat = tilingData->mm2.enLoadAllAmat; + mlaTilingData.mm2.b0matPingPongBufferLen = tilingData->mm2.b0matPingPongBufferLen; + + mlaTilingData.mm3.numBatch = tilingData->mm3.numBatch; + mlaTilingData.mm3.m = tilingData->mm3.m; + mlaTilingData.mm3.k = tilingData->mm3.k; + mlaTilingData.mm3.n = tilingData->mm3.n; + mlaTilingData.mm3.m0 = tilingData->mm3.m0; + mlaTilingData.mm3.k0 = tilingData->mm3.k0; + mlaTilingData.mm3.n0 = tilingData->mm3.n0; + mlaTilingData.mm3.mLoop = tilingData->mm3.mLoop; + mlaTilingData.mm3.kLoop = tilingData->mm3.kLoop; + mlaTilingData.mm3.nLoop = tilingData->mm3.nLoop; + mlaTilingData.mm3.coreLoop = tilingData->mm3.coreLoop; + mlaTilingData.mm3.swizzleCount = tilingData->mm3.swizzleCount; + mlaTilingData.mm3.enShuffleK = tilingData->mm3.enShuffleK; + mlaTilingData.mm3.blockDim = tilingData->mm3.blockDim; + + mlaTilingData.perTaskNum = tilingData->perTaskNum; + mlaTilingData.resTaskNum = tilingData->resTaskNum; + mlaTilingData.numCore = tilingData->numCore; + + mlaTilingData.rmsNumCore1 = tilingData->rmsNumCore1; + mlaTilingData.rmsNumCol1 = tilingData->rmsNumCol1; + mlaTilingData.rmsNumCore2 = tilingData->rmsNumCore2; + mlaTilingData.rmsNumCol2 = tilingData->rmsNumCol2; + + mlaTilingData.hiddenSizeQ = tilingData->hiddenSizeQ; + mlaTilingData.headNumQ = tilingData->headNumQ; + mlaTilingData.headDim = tilingData->headDim; + mlaTilingData.concatSize = tilingData->concatSize; + mlaTilingData.rotaryCoeff = tilingData->rotaryCoeff; + mlaTilingData.ntokens = tilingData->ntokens; + mlaTilingData.realCore = tilingData->realCore; + mlaTilingData.nlCoreRun = tilingData->nlCoreRun; + mlaTilingData.lCoreRun = tilingData->lCoreRun; + mlaTilingData.maxNPerLoopForUb = tilingData->maxNPerLoopForUb; + mlaTilingData.preCoreLoopTime = tilingData->preCoreLoopTime; + mlaTilingData.preCoreLoopNLast = tilingData->preCoreLoopNLast; + mlaTilingData.lastCoreLoopTime = tilingData->lastCoreLoopTime; + mlaTilingData.lastCoreLoopNLast = tilingData->lastCoreLoopNLast; + + mlaTilingData.esqFrontCore = tilingData->esqFrontCore; + mlaTilingData.esqTailCore = tilingData->esqTailCore; + mlaTilingData.esqFrontCoreBatch = tilingData->esqFrontCoreBatch; + mlaTilingData.esqTailCoreBatch = tilingData->esqTailCoreBatch; + mlaTilingData.esqHeadNum = tilingData->esqHeadNum; + mlaTilingData.esqColNum = tilingData->esqColNum; + mlaTilingData.esqUbHeadLoop = tilingData->esqUbHeadLoop; + mlaTilingData.esqHeadPerLoop = tilingData->esqHeadPerLoop; + mlaTilingData.esqHeadTail = tilingData->esqHeadTail; + mlaTilingData.esqColLoop = tilingData->esqColLoop; + mlaTilingData.esqColTail = tilingData->esqColTail; + + mlaTilingData.s1Offset = tilingData->s1Offset; + mlaTilingData.s2Offset = tilingData->s2Offset; + mlaTilingData.s3Offset = tilingData->s3Offset; + mlaTilingData.s4Offset = tilingData->s4Offset; + mlaTilingData.s5Offset = tilingData->s5Offset; + + GM_ADDR s1 = workspace + static_cast(mlaTilingData.s1Offset); + GM_ADDR s2 = workspace + static_cast(mlaTilingData.s2Offset); + GM_ADDR s3 = workspace + static_cast(mlaTilingData.s3Offset); + GM_ADDR s4 = workspace + 
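+// s1..s5 carve the shared device workspace into the intermediate buffers handed to
+// the MLAOperation stages; the byte offsets come from host-side tiling. The switch
+// below dispatches on tilingKey to the matching MLAOperation instantiation
+// (FP16 vs BF16, cache mode), running ProcessCube() on cube cores (AIC) and
+// ProcessVector() on vector cores (AIV).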
static_cast(mlaTilingData.s4Offset); + GM_ADDR s5 = workspace + static_cast(mlaTilingData.s5Offset); + + switch (mlaTilingData.tilingKey) { + case KEY_FP16_CACHEMODE_0_QUANTMODE_0: { + MLAPO_FP16::MLAOperation opFp16Cm0Qm0( + mlaTilingData, tiling); + opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3); + if ASCEND_IS_AIC { + opFp16Cm0Qm0.ProcessCube(); + } + if ASCEND_IS_AIV { + opFp16Cm0Qm0.ProcessVector(); + } + break; + } + case KEY_FP16_CACHEMODE_1_QUANTMODE_0: { + MLAPO_FP16::MLAOperation + opFp16Cm1Qm0(mlaTilingData, tiling); + opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3); + if ASCEND_IS_AIC { + opFp16Cm1Qm0.ProcessCube(); + } + if ASCEND_IS_AIV { + opFp16Cm1Qm0.ProcessVector(); + } + break; + } + case KEY_BF16_CACHEMODE_0_QUANTMODE_0: { + MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, + QuantMode::PER_TENSOR_ASYMM_QUANT> + opBf16Cm0Qm0(mlaTilingData, tiling); + opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3, s4, s5); + if ASCEND_IS_AIC { + opBf16Cm0Qm0.ProcessCube(); + } + if ASCEND_IS_AIV { + opBf16Cm0Qm0.ProcessVector(); + } + break; + } + case KEY_BF16_CACHEMODE_1_QUANTMODE_0: { + MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, + QuantMode::PER_TENSOR_ASYMM_QUANT> + opBf16Cm1Qm0(mlaTilingData, tiling); + opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3, s4, s5); + if ASCEND_IS_AIC { + opBf16Cm1Qm0.ProcessCube(); + } + if ASCEND_IS_AIV { + opBf16Cm1Qm0.ProcessVector(); + } + break; + } + case KEY_BF16_CACHEMODE_3_QUANTMODE_0: { + MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, + QuantMode::PER_TENSOR_ASYMM_QUANT> + opBf16Cm3Qm0(mlaTilingData, tiling); + opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, + quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3, s4, s5); + if ASCEND_IS_AIC { + opBf16Cm3Qm0.ProcessCube(); + } + if ASCEND_IS_AIV { + opBf16Cm3Qm0.ProcessVector(); + } + break; + } + default: { + break; + } + } + return; +} + +namespace vllm_ascend { + +extern void mla_preprocess_impl( + void* stream, + void* hidden_state, + void* gamma1, + void* beta1, + void* quant_scale1, + void* quant_offset1, + void* wdqkv, + void* bias1, + void* gamma2, + void* beta2, + void* quant_scale2, + void* quant_offset2, + void* gamma3, + void* sin1, + void* cos1, + void* sin2, + void* cos2, + void* keycache, + void* 
slot_mapping, + void* wuq, + void* bias2, + void* wuk, + void* descale1, + void* descale2, + void* ctkv_scale, + void* qnope_scale, + void* q, + void* keycache_out, + void* q2, + void* keycache_out2, + void* workspace, + void* tiling, + const uint32_t block_dim) +{ + mla_preprocess<<>>( + hidden_state, + gamma1, + beta1, + quant_scale1, + quant_offset1, + wdqkv, + bias1, + gamma2, + beta2, + quant_scale2, + quant_offset2, + gamma3, + sin1, + cos1, + sin2, + cos2, + keycache, + slot_mapping, + wuq, + bias2, + wuk, + descale1, + descale2, + ctkv_scale, + qnope_scale, + q, + keycache_out, + q2, + keycache_out2, + workspace, + tiling); +} + +} diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp new file mode 100644 index 0000000000..f58f4aa715 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp @@ -0,0 +1,2918 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// + +#include "kernel/common.h" +#include "kernel/iterator.h" +#include "kernel/mem.h" +#include "kernel/mma.h" +#include "kernel/utils.h" +#include "kernel/simd.h" +#include "kernel/kernel_utils.h" + +#include "lib/matmul_intf.h" + +#include "mla_preprocess.h" +#include "../op_host/tiling/mla_preprocess_tiling.h" + +namespace MLAPO_BF16 { +template +class RopeFp16 +{ +public: + __aicore__ inline RopeFp16() : blockIdx_(AscendC::GetBlockIdx()) {} + + __aicore__ inline void RopeInit(GM_ADDR qGm, AscendC::GlobalTensor &cosGm, + AscendC::GlobalTensor &sinGm, + AscendC::GlobalTensor &outRopeConcatGm, + AscendC::GlobalTensor &outRopeConcatGm2, MlaTilingData &ropeConcatParams) + { + qGm_.SetGlobalBuffer(reinterpret_cast<__gm__ QkDtype *>(qGm)); + this->cosGm_ = cosGm; + this->sinGm_ = sinGm; + this->outRopeConcatGm_ = outRopeConcatGm; + this->outRopeConcatGm2_ = outRopeConcatGm2; + + headDim = ropeConcatParams.headDim; + headNumQ = ropeConcatParams.headNumQ; + rotaryCoeff = ropeConcatParams.rotaryCoeff; + ntokens = ropeConcatParams.ntokens; + realCore = ropeConcatParams.realCore; + nlCoreRun = ropeConcatParams.nlCoreRun; + lCoreRun = ropeConcatParams.lCoreRun; + maxNPerLoopForUb = ropeConcatParams.maxNPerLoopForUb; + preCoreLoopTime = ropeConcatParams.preCoreLoopTime; + preCoreLoopNLast = ropeConcatParams.preCoreLoopNLast; + lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime; + lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast; + concatSize = ropeConcatParams.concatSize; + blockIdx_ = (blockIdx_ / 2) * 2 + static_cast(GetSubBlockidx()); + loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime; + lastLoopN = (blockIdx_ == realCore - 1) ? 
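+// Each cube core appears to pair two vector sub-blocks, so the effective vector
+// index is rebuilt from the pair index plus GetSubBlockidx(); the last core then
+// takes the tail loop count and tail N selected below.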
lastCoreLoopNLast : preCoreLoopNLast; + this->repeatSize_ = 64; // 128 = 256B / sizeof(fp32) + this->rotateStride_ = this->headDim / this->rotaryCoeff; + headBlockLen = static_cast(this->headDim / ELE_NUM_FP16); + headBlockLenFP32 = static_cast(this->headDim / ELE_NUM_FP32); + rotaryLen = static_cast(this->rotateStride_ / ELE_NUM_FP32); + concatBlockLen = static_cast(this->concatSize / ELE_NUM_FP16); + outLineOffset = this->headDim + this->concatSize; + uint32_t dataNum = this->headDim * this->maxNPerLoopForUb; + dataSizeFp16 = dataNum * sizeof(QkDtype); + dataSizeFp32 = dataNum * sizeof(float); + uint32_t concatDataSize = this->concatSize * sizeof(QkDtype) * this->maxNPerLoopForUb; + } + + __aicore__ inline void Process() + { + if (blockIdx_ >= realCore) { + return; + } + uint64_t startCoreLineIndex = this->blockIdx_ * this->nlCoreRun; + // [maxNPerLoopForUb,head_dim] 的 neg + AscendC::LocalTensor negLocal = + buf.GetBuffer(dataSizeFp32 * 4 + dataSizeFp16 * 3); + ExpandNeg(negLocal, this->maxNPerLoopForUb); + + SET_FLAG(MTE3, MTE2, EVENT_ID1); + for (uint32_t zz = 0; zz < this->loopTime; ++zz) { + uint16_t loopN = (zz == this->loopTime - 1) ? this->lastLoopN : this->maxNPerLoopForUb; + uint64_t startHead = startCoreLineIndex + zz * this->maxNPerLoopForUb; + uint64_t endHead = startHead + loopN; + + // move in Q + AscendC::LocalTensor inputQ = buf.GetBuffer(0); + AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); + AscendC::LocalTensor reverseQ = + buf.GetBuffer(dataSizeFp32 + dataSizeFp16); + uint64_t qOffset = startHead * 192 + 128; + CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); + + // move in cos/sin + AscendC::LocalTensor inputCos = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16); + AscendC::LocalTensor inputSin = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 2); + uint64_t startSinCosHeadIndex = startHead; + uint64_t headRemain = startHead % this->headNumQ; + uint64_t localStartAddr = 0; + if (headRemain != 0) { + uint64_t preProcessHeadNum = this->headNumQ - headRemain; + uint64_t needToProcesHead = preProcessHeadNum > loopN ? loopN : preProcessHeadNum; + CopyCosSin(inputCos, inputSin, localStartAddr, (startSinCosHeadIndex / this->headNumQ) * this->headDim, + needToProcesHead); + startSinCosHeadIndex += needToProcesHead; + localStartAddr += needToProcesHead * this->headDim; + } + + if (startSinCosHeadIndex < endHead) { + uint64_t startSinCosIndex = startSinCosHeadIndex / this->headNumQ; + uint64_t endSinCosIndex = (endHead + this->headNumQ - 1) / this->headNumQ; + for (uint32_t index = startSinCosIndex; index < endSinCosIndex; ++index) { + uint32_t repeatNum = + index == endSinCosIndex - 1 ? 
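+// cos/sin are stored once per token, so a token's row is replicated across all of
+// its heads before use. The Mul/Add sequence further below then computes
+// rotate-half RoPE in fp32, roughly q_out = q * cos + sign * swap_halves(q) * sin,
+// with sign = -1 for the first half and +1 for the second (negLocal).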
endHead - index * this->headNumQ : this->headNumQ; + CopyCosSin(inputCos, inputSin, localStartAddr, index * this->headDim, repeatNum); + localStartAddr += this->headDim * this->headNumQ; + } + } + AscendC::LocalTensor inputCosCastFP32 = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 3); + AscendC::LocalTensor inputSinCastFP32 = + buf.GetBuffer(dataSizeFp32 * 3 + dataSizeFp16 * 3); + AscendC::Cast(inputCosCastFP32, inputCos, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::Cast(inputSinCastFP32, inputSin, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + + uint32_t repeatTime = this->headDim * loopN; + AscendC::Mul(inputQCastFP32, inputCosCastFP32, inputQCastFP32, repeatTime); + AscendC::Mul(reverseQ, negLocal, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Mul(reverseQ, inputSinCastFP32, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Add(inputQCastFP32, reverseQ, inputQCastFP32, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Cast(inputQ, inputQCastFP32, AscendC::RoundMode::CAST_RINT, loopN * this->headDim); + AscendC::PipeBarrier(); + uint64_t outQOffset = startHead * outLineOffset + this->concatSize; + uint64_t outQOffset2 = startHead * this->headDim; + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + if constexpr (CacheMode == CACHE_MODE_KVCACHE) { + AscendC::DataCopy(this->outRopeConcatGm_[outQOffset], inputQ, {loopN, headBlockLen, 0, concatBlockLen}); + } else { + AscendC::DataCopy(this->outRopeConcatGm2_[outQOffset2], inputQ, loopN * this->headDim); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + // tensor -1 -1 -1 1 1 1 + template + __aicore__ inline void ExpandNeg(const AscendC::LocalTensor &tempBuf, uint32_t headNumTemp) + { + for (uint32_t i = 0; i < this->rotateStride_; ++i) { + tempBuf.SetValue(i, (BUF_TYPE)-1); + tempBuf.SetValue(i + this->rotateStride_, (BUF_TYPE)1); + } + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + AscendC::Copy(tempBuf[this->headDim], tempBuf, this->headDim, headNumTemp - 1, {1, 1, headBlockLenFP32, 0}); + AscendC::PipeBarrier(); + } + + template + __aicore__ inline void CopyQGenReverseQ(const AscendC::LocalTensor &tempBufQ, + const AscendC::LocalTensor &tempBufQCast, + const AscendC::LocalTensor &tempBufRverseQ, uint64_t qOffset, + uint16_t loopN) + { + // move in Q + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + // cast fp32 + AscendC::Cast(tempBufQCast, tempBufQ, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + // move out reverseQ + AscendC::DataCopy(tempBufRverseQ, tempBufQCast[this->rotateStride_], {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::DataCopy(tempBufRverseQ[this->rotateStride_], tempBufQCast, {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::PipeBarrier(); + } + + template + __aicore__ inline void CopyCosSin(const AscendC::LocalTensor &tempBufCos, + const AscendC::LocalTensor &tempBufSin, uint64_t localStartAddr, + uint64_t gmStartAddr, uint64_t repeatNum) + { + AscendC::DataCopy(tempBufCos[localStartAddr], this->cosGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + AscendC::DataCopy(tempBufSin[localStartAddr], this->sinGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + AscendC::Copy(tempBufCos[localStartAddr + this->headDim], 
tempBufCos[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::Copy(tempBufSin[localStartAddr + this->headDim], tempBufSin[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::PipeBarrier(); + } + +private: + AsdopsBuffer buf; + + AscendC::GlobalTensor qGm_; + AscendC::GlobalTensor cosGm_; + AscendC::GlobalTensor sinGm_; + AscendC::GlobalTensor outRopeConcatGm_; + AscendC::GlobalTensor outRopeConcatGm2_; + + uint32_t repeatSize_{0}; + uint32_t rotateStride_{0}; // this->headDim / rope conf + uint32_t headDim; + uint32_t headNumQ; + uint32_t rotaryCoeff; + uint32_t ntokens; + uint32_t realCore; + uint32_t nlCoreRun; + uint32_t lCoreRun; + uint32_t maxNPerLoopForUb; + uint32_t preCoreLoopTime; + uint32_t preCoreLoopNLast; + uint32_t lastCoreLoopTime; + uint32_t lastCoreLoopNLast; + uint32_t concatSize; + uint32_t blockIdx_; + uint32_t loopTime{0}; + uint32_t lastLoopN{0}; + + uint32_t dataSizeFp32; + uint32_t dataSizeFp16; + uint16_t headBlockLen{0}; + uint16_t headBlockLenFP32{0}; + uint16_t rotaryLen{0}; + uint16_t concatBlockLen{0}; + uint64_t outLineOffset{0}; +}; + +__aicore__ inline void ReduceSumCustom(const AscendC::LocalTensor &dst_local, + const AscendC::LocalTensor &src_local, + const AscendC::LocalTensor &work_local, int32_t count) +{ +#ifdef __DAV_C220_VEC__ + uint64_t mask = NUM_PER_REP_FP32; + int32_t repeatTimes = count / NUM_PER_REP_FP32; + int32_t tailCount = count % NUM_PER_REP_FP32; + int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = AscendC::ONE_REPEAT_BYTE_SIZE / AscendC::ONE_BLK_SIZE; + repeatParams.src0BlkStride = 1; + repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 0; + repeatParams.dstBlkStride = 1; + Duplicate(work_local, ZERO, NUM_PER_REP_FP32); + AscendC::PipeBarrier(); + if (likely(repeatTimes > 0)) { + Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams); + AscendC::PipeBarrier(); + } + if (unlikely(tailCount != 0)) { + Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams); + AscendC::PipeBarrier(); + } + AscendC::AscendCUtils::SetMask(NUM_PER_REP_FP32); + cadd_v(dst_local, // dst + work_local, // src + 1, // repeat + 0, // dstRepeatStride + 1, // srcBlockStride + 0); // srcRepeatStride + AscendC::PipeBarrier(); +#endif +} + +template +class Quant +{ +public: + __aicore__ inline Quant() {} + + __aicore__ inline void Init(AscendC::GlobalTensor &quantScaleGmTensor, + AscendC::GlobalTensor &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm, + GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride, + uint32_t num_col, uint64_t gm_offset, uint64_t gm_out_offset, uint32_t row_work_, + const MlaTilingData &mlaParams_) + { + this->quantScaleGmTensor = quantScaleGmTensor; + this->quantOffsetGmTensor = quantOffsetGmTensor; + this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perTokenDescaleGm)); + this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm)); + if constexpr (!NEED_DEQUANT) { + inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput)); + } else { + mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput)); + } + outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput)); + + num_col_ = num_col; + quantMin_ = -128; + this->num_row_ = mlaParams_.n; + this->row_work = row_work; + this->row_work_ = 
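+// row_work_ is the number of rows handled by this core; the num_col_align_* values
+// below round the column count (and the stride-adjusted count) up to a whole vector
+// repeat (256 int8 / 128 fp16 / 64 fp32 elements per 256-byte repeat).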
row_work_; + gm_offset_ = gm_offset; + gm_out_offset_ = gm_out_offset; + num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + input_stride_ = stride; + + num_col_align_withStride_int8 = + (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_withStride_fp16 = + (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_withStride_fp32 = + (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + } + + __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, + const AscendC::LocalTensor &srcTensor, + const AscendC::LocalTensor &quantScaleTensor, + const AscendC::LocalTensor &quantOffsetTensor, + const AscendC::LocalTensor &res1Tensor, + const AscendC::LocalTensor &res3Tensor) + { + this->dstTensor = dstTensor; + this->srcTensor = srcTensor; + this->fp32_xy = res1Tensor; + this->buf = res3Tensor; + + AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 + AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 + AscendC::LocalTensor work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 + AscendC::LocalTensor abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3 + AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4 + AscendC::LocalTensor max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5 + AscendC::LocalTensor perTokenDescaleTensor = + buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6 + + SET_FLAG(MTE2, V, EVENT_ID1); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + } + + if constexpr (NEED_DEQUANT) { + mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); + } + + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + if (std::is_same::value) { + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + input_scale_ = 1 / (float)(g.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } else { + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + WAIT_FLAG(MTE2, V, EVENT_ID1); + uint64_t pid = 0; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + while (pid < row_work_) { + uint64_t offset = pid * num_col_; + uint64_t outOffset = pid * (num_col_ - input_stride_); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], 
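+// Without NEED_DEQUANT the row is copied in as-is; with NEED_DEQUANT the row is the
+// int32 matmul output, which is cast to fp32 and rescaled by the per-channel
+// descale and the per-token descale before the quantization below.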
+ AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + } else { + /* Dequant start */ + AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112 + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_); + AscendC::PipeBarrier(); + AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, + num_col_); + SET_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID0); + gm_to_ub_align(perTokenDescaleTensor, perTokenDescaleGmTensor[pid], + 0, // sid + 1, // nBurst + sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + float perTokenDescale = perTokenDescaleTensor.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, + num_col_); + AscendC::PipeBarrier(); + AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, num_col_); + AscendC::PipeBarrier(); + } + + Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + + /* Quant start */ + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + } else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + ReduceMax(max, abs, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + float scaleOut = max.GetValue(0) / 127; + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + perTokenDescaleTensor.SetValue(0, scaleOut); + SET_FLAG(S, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + ub_to_gm_align(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } else { + ub_to_gm_align(perTokenDescaleGmTensor[num_row_ + pid], + perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + } + + AscendC::LocalTensor tmpfp16 = + buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); + AscendC::PipeBarrier(); + 
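+// The fp32 result is narrowed to fp16 and then converted to int8, with quantMin_
+// (-128) passed as what looks like the lower saturation bound.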
CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + ++pid; + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + } + +private: + AscendC::LocalTensor dstTensor; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor fp32_xy; + AscendC::LocalTensor buf; + AscendC::LocalTensor mmTensor; + AscendC::LocalTensor deScaleTensor; + + AscendC::GlobalTensor quantScaleGmTensor; + AscendC::GlobalTensor quantOffsetGmTensor; + AscendC::GlobalTensor inputGmTensor; + AscendC::GlobalTensor outputGmTensor; + AscendC::GlobalTensor perTokenDescaleGmTensor; + AscendC::GlobalTensor perChannelDescaleGmTensor; + AscendC::GlobalTensor mmGmTensor; + + uint32_t num_col_{0}; // input columns + uint32_t num_row_{0}; // input rows + uint32_t row_work_{0}; // rows need process + uint32_t row_work{0}; // rows need process + uint32_t row_step_{0}; // rows move in once + uint32_t row_tail_{0}; // rows move in last time + uint64_t gm_offset_{0}; // GM data offset + uint64_t gm_out_offset_{0}; // GM data offset + float avg_factor_{1.0}; // 1/num_col_ + float input_scale_{1.0}; + float input_offset_{0}; + int32_t input_stride_{0}; + float epsilon_{1e-12f}; + uint32_t num_col_align_int8{0}; + uint32_t num_col_align_f16{0}; + uint32_t num_col_align_f32{0}; + uint32_t num_col_align_f32_long{0}; + uint32_t num_col_align_withStride_int8{0}; + uint32_t num_col_align_withStride_fp16{0}; + uint32_t num_col_align_withStride_fp32{0}; + uint32_t num_col_temp; + half quantMin_{-128}; + uint32_t num_slice_{0}; + uint32_t tail_size_{0}; + uint32_t tail_copy_{0}; +}; + +template +class RmsNormQuant +{ +public: + __aicore__ inline RmsNormQuant() {} + + __aicore__ inline void Init(AscendC::GlobalTensor &gammaGmTensor, AscendC::GlobalTensor &betaGmTensor, + AscendC::GlobalTensor &quantScaleGmTensor, + AscendC::GlobalTensor &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm, + GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride, + uint32_t num_col, float avg_factor, uint64_t gm_offset, uint64_t gm_out_offset, + uint32_t row_work_, const MlaTilingData &mlaParams_) + { + this->gammaGmTensor = gammaGmTensor; + this->betaGmTensor = betaGmTensor; + this->quantScaleGmTensor = quantScaleGmTensor; + this->quantOffsetGmTensor = quantOffsetGmTensor; + this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perTokenDescaleGm)); + this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm)); + if constexpr (!NEED_DEQUANT) { + inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput)); + } else { + mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput)); + } + outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput)); + + num_col_ = num_col; + avg_factor_ = avg_factor; + epsilon_ = 1e-6; + quantMin_ = -128; + this->num_row_ = mlaParams_.n; + this->row_work = row_work; + this->row_work_ = row_work_; + gm_offset_ = gm_offset; + gm_out_offset_ = gm_out_offset; + num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + input_stride_ 
= stride; + + num_col_align_withStride_int8 = + (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_withStride_fp16 = + (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_withStride_fp32 = + (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + } + + __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &betaTensor, + const AscendC::LocalTensor &quantScaleTensor, + const AscendC::LocalTensor &quantOffsetTensor, + const AscendC::LocalTensor &res1Tensor, + const AscendC::LocalTensor &res3Tensor) + { + this->dstTensor = dstTensor; + this->srcTensor = srcTensor; + this->gammaTensor = gammaTensor; + this->betaTensor = betaTensor; + this->fp32_xy = res1Tensor; + this->buf = res3Tensor; + + AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 + AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 + AscendC::LocalTensor work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 + AscendC::LocalTensor abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3 + AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4 + AscendC::LocalTensor max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5 + AscendC::LocalTensor perTokenDescaleTensor = + buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6 + + AscendC::DataCopy(gammaTensor, gammaGmTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); + AscendC::DataCopy(betaTensor, betaGmTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID1); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + } + + if constexpr (NEED_DEQUANT) { + mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); + } + + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + if (std::is_same::value) { + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + input_scale_ = 1 / (float)(g.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } else { + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(buf[OFFSET_GAMMA * num_col_align_withStride_fp32], gammaTensor, AscendC::RoundMode::CAST_NONE, + REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + 
AscendC::PipeBarrier(); + uint64_t pid = 0; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + while (pid < row_work_) { + uint64_t offset = pid * num_col_; + uint64_t outOffset = pid * (num_col_ - input_stride_); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + } else { + /* Dequant start */ + AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112 + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_); + AscendC::PipeBarrier(); + AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, + num_col_); + SET_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID0); + gm_to_ub_align(perTokenDescaleTensor, perTokenDescaleGmTensor[pid], + 0, // sid + 1, // nBurst + sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + float perTokenDescale = perTokenDescaleTensor.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, + num_col_); + AscendC::PipeBarrier(); + AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, num_col_); + AscendC::PipeBarrier(); + } + + Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + Mul(sqx, fp32_xy, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Muls(sqx, sqx, avg_factor_, num_col_ - input_stride_); + AscendC::PipeBarrier(); + ReduceSumCustom(sum, sqx, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + Adds(sum, sum, epsilon_, 1); + AscendC::PipeBarrier(); + Sqrt(sum, sum, 1); + SET_FLAG(V, S, EVENT_ID0); + WAIT_FLAG(V, S, EVENT_ID0); + float factor = 1 / sum.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, factor, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Mul(fp32_xy, fp32_xy, g, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + if constexpr (WITH_BETA) { + AscendC::LocalTensor b = this->betaTensor; + Cast(work, b, AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + Add(fp32_xy, fp32_xy, work, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + } + /* Quant start */ + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) 
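+// At this point fp32_xy holds gamma * x / sqrt(mean(x^2) + eps) (plus beta when
+// WITH_BETA). Per-tensor mode applies the static scale/offset; per-token mode
+// derives the scale from the row's max(|x|) / 127 and stores it as a descale.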
{ + Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + } else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + ReduceMax(max, abs, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + float scaleOut = max.GetValue(0) / 127; + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + perTokenDescaleTensor.SetValue(0, scaleOut); + SET_FLAG(S, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + ub_to_gm_align(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } else { + ub_to_gm_align(perTokenDescaleGmTensor[num_row_ + pid], + perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + } + + AscendC::LocalTensor tmpfp16 = + buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); + AscendC::PipeBarrier(); + CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + ++pid; + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + } + +private: + AscendC::LocalTensor dstTensor; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor gammaTensor; + AscendC::LocalTensor betaTensor; + AscendC::LocalTensor fp32_xy; + AscendC::LocalTensor buf; + AscendC::LocalTensor mmTensor; + AscendC::LocalTensor deScaleTensor; + + AscendC::GlobalTensor gammaGmTensor; + AscendC::GlobalTensor betaGmTensor; + AscendC::GlobalTensor quantScaleGmTensor; + AscendC::GlobalTensor quantOffsetGmTensor; + AscendC::GlobalTensor inputGmTensor; + AscendC::GlobalTensor outputGmTensor; + AscendC::GlobalTensor perTokenDescaleGmTensor; + AscendC::GlobalTensor perChannelDescaleGmTensor; + AscendC::GlobalTensor mmGmTensor; + + uint32_t num_col_{0}; + uint32_t num_row_{0}; + uint32_t row_work_{0}; + uint32_t row_work{0}; + uint32_t row_step_{0}; + uint32_t row_tail_{0}; + uint64_t gm_offset_{0}; + uint64_t gm_out_offset_{0}; + float avg_factor_{1.0}; + float input_scale_{1.0}; + float input_offset_{0}; + int32_t input_stride_{0}; + float epsilon_{1e-12f}; + uint32_t num_col_align_int8{0}; + uint32_t num_col_align_f16{0}; + uint32_t num_col_align_f32{0}; + uint32_t num_col_align_f32_long{0}; + uint32_t num_col_align_withStride_int8{0}; + uint32_t num_col_align_withStride_fp16{0}; + 
uint32_t num_col_align_withStride_fp32{0}; + uint32_t num_col_temp; + half quantMin_{-128}; + uint32_t num_slice_{0}; + uint32_t tail_size_{0}; + uint32_t tail_copy_{0}; +}; + +template +class EinSumQuant +{ +public: + __aicore__ explicit EinSumQuant() {} + + __aicore__ __force_inline__ void Init(GM_ADDR einSumOutGm, GM_ADDR scaleGm, GM_ADDR quantOutGm, + const MlaTilingData &tilingData) + { + einSumOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(einSumOutGm)); + scaleGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(scaleGm)); + quantOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOutGm)); + + headNum = tilingData.esqHeadNum; + colNum = tilingData.esqColNum; + ubHeadLoop = tilingData.esqUbHeadLoop; + headPerLoop = tilingData.esqHeadPerLoop; + headTail = tilingData.esqHeadTail; + colLoop = tilingData.esqColLoop; + colTail = tilingData.esqColTail; + + currentIdx = (AscendC::GetBlockIdx() / 2) * 2 + GetSubBlockidx(); + if (currentIdx < tilingData.esqFrontCore) { + batchNum = tilingData.esqFrontCoreBatch; + currentCoreStartOffset = currentIdx * tilingData.esqFrontCoreBatch * headNum * colNum; + } else { + batchNum = tilingData.esqTailCoreBatch; + currentCoreStartOffset = (tilingData.esqFrontCore * tilingData.esqFrontCoreBatch + + (currentIdx - tilingData.esqFrontCore) * tilingData.esqTailCoreBatch) * + headNum * colNum; + } + calcRepeatStride = static_cast(colNum / ELE_NUM_FP32); + padLen = RoundUp(headNum, ELE_NUM_FP16); + calcLength = headPerLoop * colNum; + + // calc tensors' data size(bytes) and block + scaleBrcbFp32DataSize = padLen * ELE_NUM_FP32 * sizeof(float); + inputDataSize = calcLength * sizeof(InDtype); + inputDataBlock = calcLength * sizeof(InDtype) / BLOCK_SIZE_32; + inputFp32DataSize = calcLength * sizeof(float); + int8OutDataBlcok = calcLength / BLOCK_SIZE_32; + headTailDataBlock = headTail * colNum * sizeof(InDtype) / BLOCK_SIZE_32; + int8TailOutDataBlock = headTail * colNum / BLOCK_SIZE_32; + if (padLen > headNum) { + scaleCopyParams = AscendC::DataCopyExtParams(1, static_cast(headNum * sizeof(InDtype)), 0, 0, 0); + scalePadParams = AscendC::DataCopyPadExtParams(true, 0, static_cast(padLen - headNum), 0); + } + } + + __aicore__ __force_inline__ void Process() + { + if (batchNum == 0) { + return; + } + // init local tensor + scaleBrcbFp32_ = buf.GetBuffer(0); + inputTensor_ = buf.GetBuffer(scaleBrcbFp32DataSize); + inputFp32_ = + buf.GetBuffer(scaleBrcbFp32DataSize + inputDataSize * ROPE_CONCAT_NUM_BUFFER); + int8OutTensor_ = buf.GetBuffer( + scaleBrcbFp32DataSize + (inputDataSize + inputFp32DataSize) * ROPE_CONCAT_NUM_BUFFER); + + // scale copy in, cast, brcb[H, 1] --> [H, 8], use input ub space + if (headNum == padLen) { + AscendC::DataCopy(inputTensor_, scaleGm_, headNum); + } else { + AscendC::DataCopyPad(inputTensor_, scaleGm_, scaleCopyParams, scalePadParams); + } + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + AscendC::Cast(inputFp32_, inputTensor_, AscendC::RoundMode::CAST_NONE, padLen); + AscendC::PipeBarrier(); + AscendC::Brcb(scaleBrcbFp32_, inputFp32_, padLen / ELE_NUM_FP32, {1, 8}); + AscendC::PipeBarrier(); + + uint8_t pingFlag = 0; + // batch Loop + SET_FLAG(V, MTE2, EVENT_ID0); // input copy in wait vector release ub + SET_FLAG(V, MTE2, EVENT_ID1); + SET_FLAG(MTE3, V, EVENT_ID0); // quant calc wait last result copyout + SET_FLAG(MTE3, V, EVENT_ID1); + for (uint32_t batchIdx = 0; batchIdx < batchNum; batchIdx++) { + batchOffset = batchIdx * headNum * colNum; + // ub Loop + for (uint32_t ubLoopIdx = 0; ubLoopIdx < 
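+// Heads are processed in UB-sized chunks through two ping-pong buffers (pingFlag
+// selects the half); each chunk is scaled per head by the broadcast scale, cast
+// fp32 -> fp16 -> int8 and copied back out.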
ubHeadLoop; ubLoopIdx++) { + scaleBrcbOffset = ubLoopIdx * headPerLoop * ELE_NUM_FP32; + inputLoopOffset = ubLoopIdx * headPerLoop * colNum; + calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset; + calcTmpOffset = pingFlag * calcLength; + + // input CopyIn and Cast + WAIT_FLAG(V, MTE2, pingFlag); + AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset], + {1, inputDataBlock, 0, 0}); + SET_FLAG(MTE2, V, pingFlag); + WAIT_FLAG(MTE2, V, pingFlag); + AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset], AscendC::RoundMode::CAST_NONE, + calcLength); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE2, pingFlag); + // quant calc + for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) { + colOffset = colIdx * CONST_64; + AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset], + scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headPerLoop, + {1, 1, 0, calcRepeatStride, calcRepeatStride, 1}); + } + AscendC::PipeBarrier(); + // quant fp32 --> fp16 --> int8 + CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast(), inputFp32_[calcTmpOffset], + calcLength); + AscendC::PipeBarrier(); + WAIT_FLAG(MTE3, V, pingFlag); // wait last result copy out + CastFromF16ToI8(int8OutTensor_[calcTmpOffset], + inputFp32_[calcTmpOffset].template ReinterpretCast(), quantMin_, calcLength); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, pingFlag); + WAIT_FLAG(V, MTE3, pingFlag); + // int8 CopyOut + AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset], + {1, int8OutDataBlcok, 0, 0}); + SET_FLAG(MTE3, V, pingFlag); + pingFlag = 1 - pingFlag; + } + + // deal with head tail + if (headTail > 0) { + scaleBrcbOffset = ubHeadLoop * headPerLoop * ELE_NUM_FP32; + inputLoopOffset = ubHeadLoop * headPerLoop * colNum; + calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset; + calcTmpOffset = pingFlag * calcLength; + + // input CopyIn and Cast + WAIT_FLAG(V, MTE2, pingFlag); + AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset], + {1, headTailDataBlock, 0, 0}); + SET_FLAG(MTE2, V, pingFlag); + WAIT_FLAG(MTE2, V, pingFlag); + AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset], AscendC::RoundMode::CAST_NONE, + headTail * colNum); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE2, pingFlag); + // quant calc + for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) { + colOffset = colIdx * CONST_64; + AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset], + scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headTail, + {1, 1, 0, calcRepeatStride, calcRepeatStride, 1}); + } + AscendC::PipeBarrier(); + // quant fp32 --> fp16 --> int8 + CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast(), inputFp32_[calcTmpOffset], + headTail * colNum); + AscendC::PipeBarrier(); + WAIT_FLAG(MTE3, V, pingFlag); // wait last result copy out + CastFromF16ToI8(int8OutTensor_[calcTmpOffset], + inputFp32_[calcTmpOffset].template ReinterpretCast(), quantMin_, + headTail * colNum); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, pingFlag); + WAIT_FLAG(V, MTE3, pingFlag); + // int8 CopyOut + AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset], + {1, int8TailOutDataBlock, 0, 0}); + SET_FLAG(MTE3, V, pingFlag); + pingFlag = 1 - pingFlag; + } + } + WAIT_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID1); + } + +private: + AsdopsBuffer buf; + + AscendC::GlobalTensor 
einSumOutGm_; + AscendC::GlobalTensor scaleGm_; + AscendC::GlobalTensor quantOutGm_; + + AscendC::LocalTensor scaleBrcbFp32_; + AscendC::LocalTensor inputTensor_; + AscendC::LocalTensor inputFp32_; + AscendC::LocalTensor int8OutTensor_; + + AscendC::DataCopyExtParams scaleCopyParams; + AscendC::DataCopyPadExtParams scalePadParams; + + // data processed by a single core[batchNum, headNum, colNum] + uint32_t batchNum; // The number of batches per kernel processed + uint32_t headNum; + uint32_t colNum; // Number of columns per row + // ub loop + uint32_t ubHeadLoop; // The number of times the UB loops through the head. + uint32_t headPerLoop; // The number of heads processed per UB cycle + uint32_t headTail; // The number of heads last processed + // col loop + uint32_t colLoop; // The number of calculations in the column direction cycle. + uint32_t colTail; // The number of cols last processed + + uint32_t currentIdx; + uint64_t currentCoreStartOffset; + uint32_t inputDataSize; // The size of each carry,bytes + uint32_t inputFp32DataSize; + uint32_t scaleBrcbFp32DataSize; + uint16_t inputDataBlock; // The number of blocks brought in per move,bytes + uint16_t int8OutDataBlcok; + uint16_t headTailDataBlock; + uint16_t int8TailOutDataBlock; + // gm offset + uint64_t inputLoopOffset{0}; + uint64_t batchOffset{0}; + uint64_t calcStartOffset{0}; + // double buffer tmp tensor length + uint32_t scaleBrcbOffset{0}; + uint32_t calcLength{0}; + uint32_t calcTmpOffset{0}; + + half quantMin_{-128}; + uint32_t colOffset{0}; + uint32_t padLen; + uint8_t calcRepeatStride; +}; + +#ifdef __DAV_C220_CUBE__ +struct MatCoord { + uint64_t m{0}; + uint64_t k{0}; + uint64_t n{0}; +}; + +template +class PpMatmulEinSum +{ + using AccumDtype = float; + + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mad = mmad; + using CopyCcToGm = l0c_to_gm; + + static constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384; + static constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072; + static constexpr uint32_t CONST_16 = 16; + static constexpr uint32_t CONST_256 = 256; + +public: + __aicore__ explicit PpMatmulEinSum(){}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const MlaTilingData &mlaParams) + { +#ifdef __DAV_C220_CUBE__ + batch_size = mlaParams.mm3.numBatch; + m = mlaParams.mm3.m; + k = mlaParams.mm3.k; + n = mlaParams.mm3.n; + m0 = mlaParams.mm3.m0; + k0 = mlaParams.mm3.k0; + n0 = mlaParams.mm3.n0; + tdim.m = mlaParams.mm3.mLoop; + tdim.k = mlaParams.mm3.kLoop; + tdim.n = mlaParams.mm3.nLoop; + core_loop = mlaParams.mm3.coreLoop; + swizzle_cnt = mlaParams.mm3.swizzleCount; + num_core = mlaParams.mm3.blockDim; + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC)); + + AsdopsBuffer buf; + l1_base_a = buf.GetBuffer(0); + l1_base_b = buf.GetBuffer(RoundUp(m0 * k0 * sizeof(InDtype))); + l0a_base = buf.GetBuffer(0); + l0b_base = buf.GetBuffer(0); +#endif + return; + } + + __aicore__ __force_inline__ void Process() + { +#ifdef __DAV_C220_CUBE__ + if (block_idx >= num_core) { + WaitFlagDev(AIC_MM3_START); + return; + } + using LocalTensor = AscendC::LocalTensor; + + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(FIX, M, EVENT_ID0); + 
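+// These SET_FLAGs pre-arm the cross-pipeline events (MTE2 loads, MTE1 L1->L0 loads,
+// M for mmad, FIX for the L0C->GM copy) so the first iteration's WAIT_FLAGs do not
+// stall; the matching WAIT_FLAGs after the main loop drain them again.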
SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) { + uint64_t batch_idx = loop_idx / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx, tidx); + uint64_t offset_c = tidx.m * m0 * batch_size * (n + splitGapC) + batch_idx * (n + splitGapC) + tidx.n * n0; + uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round = RoundUp(m_actual); + uint64_t n_round = RoundUp(n_actual); + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16; + uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; + + uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (loop_idx == core_idx) { + WaitFlagDev(AIC_MM3_START); + + // Copy A from gm to l1 buffer + uint64_t offset_a = GetOffsetA(batch_idx, tidx.m, shuffle_k); + WAIT_FLAG(MTE1, MTE2, event_id); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); + SET_FLAG(MTE2, MTE1, event_id); + + // Copy B from gm to l1 buffer + uint64_t offset_b = GetOffsetB(batch_idx, shuffle_k, tidx.n); + WAIT_FLAG(MTE1, MTE2, event_id + 2); + CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + SET_FLAG(MTE2, MTE1, event_id + 2); + } + + for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) { + shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k; + uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + fdim.k = (k_actual + k_part_len - 1) / k_part_len; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (tidx.k < tdim.k - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1); + uint64_t offset_a_next = GetOffsetA(batch_idx, tidx.m, shuffle_k_next); + uint64_t offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, tidx.n); + + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. 
+ WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) { + uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx + num_core, tidx); + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0; + uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t offset_a_next = GetOffsetA(b_idx_next, tidx.m, shuffle_k_next); + uint64_t offset_b_next = GetOffsetB(b_idx_next, shuffle_k_next, tidx.n); + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual_next, m_round_next, k_actual_next, + k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual_next, + n_round_next); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + MatCoord fidx{0}; + for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) { + uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len; + uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len; + + auto mte1_mad_ping_flag = 1 - fidx.k % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? 
EVENT_ID0 : EVENT_ID1; + LocalTensor l0a_buf = l0a_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + LocalTensor l0b_buf = l0b_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1)) { + l1_to_l0_a( + l0a_buf, // dst + l1_buf_a[fidx.k * k_part_len], // src + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / CONST_16, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + 2); + } + if constexpr (transB) { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / CONST_16, // kSrcStride + 1, // nDstStride + k0_round / CONST_16); // kDstStride + } else { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // nSrcStride + 1, // kSrcStride + 1, // nDstStride + n_round / CONST_16); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id + 2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (tidx.k == 0 && fidx.k == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + m_actual, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC + + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // mTileActual + n_actual, // nTileActual + m_round, // mTileCeil + (n + splitGapC) * batch_size); // nActual + SET_FLAG(FIX, M, EVENT_ID0); + } + + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); +#endif + } + +private: + __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, MatCoord &tidx) + { + uint64_t in_batch_idx = index % (tdim.m * tdim.n); + if constexpr (swizzleDirect == 0) { // Zn + uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = tdim.m - swizzle_cnt * tile_block_idx; + } + tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + tidx.n = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + tidx.n = tdim.n - tidx.n - 1; + } + } else if constexpr (swizzleDirect == 1) { // Nz + uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m); + uint64_t in_tile_block_idx = in_batch_idx % 
(swizzle_cnt * tdim.m); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = tdim.n - swizzle_cnt * tile_block_idx; + } + tidx.m = in_tile_block_idx / n_col; + tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + tidx.m = tdim.m - tidx.m - 1; + } + } + return; + } + + __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t bIdx, const uint64_t mIdx, const uint64_t kIdx) + { + return mIdx * m0 * batch_size * (k + splitGapA) + bIdx * (k + splitGapA) + kIdx * k0; + } + + __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t bIdx, const uint64_t kIdx, const uint64_t nIdx) + { + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + return bIdx * k * n + nIdx * n0 * k + kIdx * k0; + } else { + return bIdx * k * n + kIdx * k0 * n + nIdx * n0; + } + } else { + if constexpr (transB) { + return bIdx * RoundUp(n) * RoundUp(k) + kIdx * k0 * RoundUp(n) + + nIdx * n0 * CONST_16; + } else { + return bIdx * RoundUp(k) * RoundUp(n) + nIdx * n0 * RoundUp(k) + + kIdx * k0 * CONST_16; + } + } + } + + __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round) + { + if ((m == 1) || (m_actual == 1)) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + 1, // nTileActual + CONST_16, // nTileCeil + 1, // nVal + k_actual, // kTileActual + k_round, // kTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + (k + splitGapA) * batch_size); // dVal + } + } + + __aicore__ __force_inline__ void CopyTileB(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round) + { + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } + } + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint32_t num_core{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + MatCoord tdim{0}; + MatCoord fdim{0}; + uint32_t core_loop{0}; + uint32_t swizzle_cnt{1}; + uint32_t core_idx{0}; + uint32_t en_shuffle_k{0}; + uint32_t ping_flag{0}; +}; + +template +class PpMatmulW8a8Aic +{ + 
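+ // Cube-core half of the W8A8 matmul: int8 activations x int8 weights accumulated into int32 in L0C.
+ // The paired PpMatmulW8a8Aiv vector kernel applies the per-tensor / per-token dequant afterwards.
+ // Weights are addressed either as plain ND or as a blocked (nZ-style) layout (see GetOffsetB /
+ // CopyTileB); output tiles are visited in a swizzled, serpentine order over the m x n tile grid,
+ // and enShuffleK staggers the k sweep across cores so they fetch different weight slices per step.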
using InDtype = int8_t; + using OutDtype = int32_t; + using AccumDtype = int32_t; + + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mmad = mmad; + using CopyCcToGm = l0c_to_gm; + + static constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; + static constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; + static constexpr uint64_t BLOCK_SIZE_16 = 16; + static constexpr uint64_t BLOCK_SIZE_32 = 32; + static constexpr uint64_t CUBE_MATRIX_SIZE_512 = 512; + static constexpr uint64_t CONST_4 = 4; + static constexpr uint64_t CONST_8 = 8; + static constexpr uint64_t CONST_32 = 32; + static constexpr uint64_t CONST_64 = 64; + static constexpr uint64_t CONST_128 = 128; + +public: + __aicore__ PpMatmulW8a8Aic() {}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, PpMatmulTilingData &tilingdata, + uint32_t mode) + { + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC)); + + batch_size = tilingdata.numBatch; + m = tilingdata.m; + k = tilingdata.k; + n = tilingdata.n; + m0 = tilingdata.m0; + k0 = tilingdata.k0; + n0 = tilingdata.n0; + m_loop = tilingdata.mLoop; + k_loop = tilingdata.kLoop; + n_loop = tilingdata.nLoop; + core_loop = tilingdata.coreLoop; + swizzle_cnt = tilingdata.swizzleCount; + en_shuffle_k = tilingdata.enShuffleK; + core_num = tilingdata.blockDim; + load_all_Amat_flag = tilingdata.enLoadAllAmat; + b0mat_pingpong_buffer_len = tilingdata.b0matPingPongBufferLen; + + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + MM1_MM2_mode = mode; // MM1 or MM2 + + InitBuffer(); + return; + } + + __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t batchIdx, const uint64_t mIdx, uint64_t kIdx) + { + return batchIdx * m * k + mIdx * m0 * k + kIdx * k0; + } + + __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t batchIdx, const uint64_t kIdx, uint64_t nIdx) + { + if constexpr (formatB == DataFormat::ND) { + return batchIdx * k * n + nIdx * n0 * k + kIdx * k0; + } else { + return batchIdx * RoundUp<16>(n) * RoundUp<32>(k) + kIdx * k0 * RoundUp<16>(n) + nIdx * n0 * CONST_32; + } + } + + __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round) + { + if ((m == 1) || (m_actual == 1)) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + 1, BLOCK_SIZE_16, 1, k_actual, k_round, k); + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } + } + + __aicore__ __force_inline__ void CopyTileB(const AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round) + { + if constexpr (formatB == DataFormat::ND) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp<16>(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp<32>(k)); // dVal + } + } + + __aicore__ __force_inline__ 
void PreloadWeight() + { + if (core_idx < core_num) { + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(core_idx, m_idx, n_idx); + uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0; + uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx); + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint64_t n_round = RoundUp(n_actual); + CopyTileB(l1_base_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + if (core_idx < core_num && k_loop > 1) { + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(core_idx, m_idx, n_idx); + uint64_t shuffle_k = en_shuffle_k ? (core_idx + 1) % k_loop : 1; + uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx); + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint64_t n_round = RoundUp(n_actual); + CopyTileB(l1_base_b[b0mat_pingpong_buffer_len], gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + } + + __aicore__ __force_inline__ void Process(); + +private: + __aicore__ __force_inline__ void InitBuffer() + { + AsdopsBuffer buf; + l1_base_a = buf.template GetBuffer(0); + + // try load all A matrix + uint32_t a_l1_size = RoundUp(m) * RoundUp(k); + if (!load_all_Amat_flag) { + a_l1_size = RoundUp(m0 * k0); + } + + l1_base_b = l1_base_a[a_l1_size]; + l0a_base = buf.template GetBuffer(0); + l0b_base = buf.template GetBuffer(0); + l0c_buf = buf.template GetBuffer(0); + } + + __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, uint64_t &m_idx, uint64_t &n_idx) + { + uint64_t in_batch_idx = index % (m_loop * n_loop); + if constexpr (swizzleDir == 0) { // Zn + uint64_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = m_loop - swizzle_cnt * tile_block_idx; + } + m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + n_idx = in_tile_block_idx / n_row; + if ((tile_block_idx & 0b1) != 0) { + n_idx = n_loop - n_idx - 1; + } + } else { // Nz + uint64_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = n_loop - swizzle_cnt * tile_block_idx; + } + m_idx = in_tile_block_idx / n_col; + n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if ((tile_block_idx & 0b1) != 0) { + m_idx = m_loop - m_idx - 1; + } + } + return; + } + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint64_t bias_bt{0}; + uint32_t core_num{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + uint32_t m_loop{0}; + uint32_t n_loop{0}; + uint32_t k_loop{0}; + uint32_t core_loop{0}; + uint32_t core_idx{0}; + uint32_t ping_flag{0}; + uint32_t swizzle_cnt{1}; + uint32_t en_shuffle_k{0}; + 
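+ // Selects which start flag this matmul instance synchronizes on:
+ // 0 -> acts as MM1 (waits on AIC_MM1_START), 1 -> acts as MM2 (waits on AIC_MM2_START).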
uint32_t MM1_MM2_mode{0}; + uint64_t b0mat_pingpong_buffer_len{0}; + bool load_all_Amat_flag{false}; +}; + +template +__aicore__ __force_inline__ void PpMatmulW8a8Aic::Process() +{ + using LocalTensor = AscendC::LocalTensor; + if (core_idx >= core_num) { + if (MM1_MM2_mode == 0) { + WaitFlagDev(AIC_MM1_START); + } else if (MM1_MM2_mode == 1) { + WaitFlagDev(AIC_MM2_START); + } + return; + } + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID7); + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) { + uint64_t batch_idx = loop_idx / n_loop / m_loop; + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(loop_idx, m_idx, n_idx); + uint64_t offset_a; + uint64_t offset_b; + uint64_t offset_bias; + uint64_t offset_a_next; + uint64_t offset_b_next; + uint64_t offset_c = batch_idx * m * n + m_idx * m0 * n + n_idx * n0; + uint64_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint64_t m_round = 0; + uint64_t n_round = 0; + uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0; + uint64_t m_round_16 = RoundUp(m_actual); + uint64_t m_round_32 = RoundUp(m_actual); + m_round = m_round_16; + n_round = RoundUp(n_actual); + + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = 0; + k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / BLOCK_SIZE_32 * BLOCK_SIZE_32; + + offset_b = GetOffsetB(batch_idx, shuffle_k, n_idx); + offset_bias = batch_idx * n + n_idx * n0; + + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + // Wait after Scalar + if (loop_idx == core_idx) { + if (MM1_MM2_mode == 0) { + WaitFlagDev(AIC_MM1_START); + } else if (MM1_MM2_mode == 1) { + WaitFlagDev(AIC_MM2_START); + } + } + + WAIT_FLAG(MTE1, MTE2, event_id); + LocalTensor l1_buf_a = + load_all_Amat_flag ? l1_base_a : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + if (load_all_Amat_flag) { + if (loop_idx == core_idx) { + offset_a = GetOffsetA(batch_idx, m_idx, 0); + uint64_t k_actual_first = k; + uint64_t k_round_first = RoundUp(k_actual_first); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual_first, k_round_first); + } + } else { + offset_a = GetOffsetA(batch_idx, m_idx, shuffle_k); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); + } + SET_FLAG(MTE2, MTE1, event_id); + + WAIT_FLAG(MTE1, MTE2, event_id + CONST_2); + // The first weight matrix block is loaded in advance. + if (loop_idx != core_idx) { + CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + SET_FLAG(MTE2, MTE1, event_id + CONST_2); + + for (uint64_t k_idx = 0; k_idx < k_loop; k_idx++) { + shuffle_k = en_shuffle_k ? (k_idx + core_idx) % k_loop : k_idx; + uint32_t k_actual = (shuffle_k == (k_loop - 1)) ? (k - shuffle_k * k0) : k0; + uint32_t k_round = RoundUp(k_actual); + uint32_t k_part_loop = (k_actual + k_part_len - 1) / k_part_len; + + // --------- load whole A in l1a addr change ------------- + LocalTensor l1_buf_a = load_all_Amat_flag ? 
(l1_base_a[k_idx * m0 * k0 * sizeof(int8_t)]) + : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (k_idx < k_loop - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + k_idx + 1) % k_loop : k_idx + 1; + + offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, n_idx); + uint32_t k_actual_next = (shuffle_k_next == (k_loop - 1)) ? (k - shuffle_k_next * k0) : k0; + uint32_t k_round_next = RoundUp(k_actual_next); + + LocalTensor l1_buf_a_next = + load_all_Amat_flag ? l1_base_a : ((1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + auto event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + WAIT_FLAG(MTE1, MTE2, event_id_next); + if (!load_all_Amat_flag) { + offset_a_next = GetOffsetA(batch_idx, m_idx, shuffle_k_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next); + } + SET_FLAG(MTE2, MTE1, event_id_next); + + WAIT_FLAG(MTE1, MTE2, event_id_next + CONST_2); + // The second weight matrix is preloaded. + if (loop_idx != core_idx || k_idx != 0) { + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round); + } + SET_FLAG(MTE2, MTE1, event_id_next + CONST_2); + } + + for (int k_part_idx = 0; k_part_idx < k_part_loop; k_part_idx++) { + uint32_t k0_round = (k_part_idx < k_part_loop - 1) ? k_part_len : k_round - k_part_idx * k_part_len; + uint32_t k0_actual = (k_part_idx < k_part_loop - 1) ? k_part_len : k_actual - k_part_idx * k_part_len; + + auto mte1_mad_ping_flag = 1 - k_part_idx % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? 
EVENT_ID0 : EVENT_ID1; + AscendC::LocalTensor l0a_buf = l0a_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN]; + AscendC::LocalTensor l0b_buf = l0b_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (k_part_idx == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1)) { + l1_to_l0_a( + l0a_buf, l1_buf_a[k_part_idx * k_part_len], + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[k_part_idx * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / BLOCK_SIZE_16, // kSrcStride + k0_round / BLOCK_SIZE_32, // mDstStride + 1); // kDstStride + } + if (k_part_idx == k_part_loop - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (k_part_idx == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + CONST_2); + } + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[k_part_idx * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / BLOCK_SIZE_16, // kSrcStride + 1, // nDstStride + k0_round / BLOCK_SIZE_32); // kDstStride + if (k_part_idx == k_part_loop - 1) { + SET_FLAG(MTE1, MTE2, event_id + CONST_2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (k_idx == 0 && k_part_idx == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + Mmad(l0c_buf, l0a_buf, l0b_buf, + m_actual, // m + n_actual, // n + k0_actual, // k + init_c); // cmatrixInitVal + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // MSize + n_actual, // NSize + m_round_16, // srcStride + n); // dstStride_dst_D + SET_FLAG(FIX, M, EVENT_ID0); + if constexpr (!withSyncAll) { + FftsCrossCoreSync(MMAIC); + if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 0) { + WaitFlagDev(MMAIV); + } + } + } + + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID7); +} + +#endif + +#if defined(__DAV_C220_VEC__) + +template +class PpMatmulW8a8Aiv +{ + using InDtype = int32_t; + using ScaleDtype = float; + using BiasDtype = int32_t; + +public: + __aicore__ PpMatmulW8a8Aiv() {}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmInput, GM_ADDR gmOutput, GM_ADDR gmDescale, GM_ADDR gmPerTensorBias, + GM_ADDR gmPertokenDescale, const PpMatmulTilingData &gmTilingData) + { + gmInput_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmInput)); + gmOutput_.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmOutput)); + gmPerTensorScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmDescale)); + gmPerTensorBias_.SetGlobalBuffer(reinterpret_cast<__gm__ BiasDtype *>(gmPerTensorBias)); + gmPerTokenScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmPertokenDescale)); + + batch_size = gmTilingData.numBatch; + m = gmTilingData.m; + k = gmTilingData.k; + n = gmTilingData.n; + m0 = gmTilingData.m0; + k0 = gmTilingData.k0; + n0 = gmTilingData.n0; + m_loop = 
gmTilingData.mLoop; + k_loop = gmTilingData.kLoop; + n_loop = gmTilingData.nLoop; + core_loop = gmTilingData.coreLoop; + swizzle_cnt = gmTilingData.swizzleCount; + swizzlDirect = gmTilingData.swizzleDirect; + en_shuffle_k = gmTilingData.enShuffleK; + + AsdopsBuffer buf; + ubInput_ = buf.GetBuffer(0); + ubTempFp32_ = buf.GetBuffer(94 * 1024); + ubOutput_ = buf.GetBuffer(0); + ubPerTensorScale_ = buf.GetBuffer(188 * 1024); + block_size = BLOCK_SIZE_32; + core_num = AscendC::GetBlockNum(); + core_idx = AscendC::GetBlockIdx() / 2; + ping_flag = 1; + } + + __aicore__ __force_inline__ void GetBlockIdx(uint32_t index, uint32_t &m_idx, uint32_t &n_idx) + { + uint32_t in_batch_idx = index % (m_loop * n_loop); + if (swizzlDirect == 0) { // Zn + uint32_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt; + uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop); + uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop); + + uint32_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = m_loop - swizzle_cnt * tile_block_idx; + } + m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + n_idx = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + n_idx = n_loop - n_idx - 1; + } + } else { // Nz + uint32_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt; + uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop); + uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop); + + uint32_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = n_loop - swizzle_cnt * tile_block_idx; + } + m_idx = in_tile_block_idx / n_col; + n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + m_idx = m_loop - m_idx - 1; + } + } + } + + __aicore__ __force_inline__ void Process(); + +private: + AscendC::GlobalTensor gmPerTensorScale_; + AscendC::GlobalTensor gmPerTensorBias_; + AscendC::GlobalTensor gmPerTokenScale_; + AscendC::GlobalTensor gmInput_; + AscendC::GlobalTensor gmOutput_; + + AscendC::LocalTensor ubInput_; + AscendC::LocalTensor ubTempFp32_; + AscendC::LocalTensor ubOutput_; + AscendC::LocalTensor ubPerTensorScale_; + + uint32_t core_num{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + uint32_t m_loop{0}; + uint32_t n_loop{0}; + uint32_t k_loop{0}; + uint32_t core_loop{0}; + uint32_t core_idx{0}; + uint32_t ping_flag{0}; + uint32_t block_size{0}; + uint32_t cube_matrix_size{0}; + uint32_t swizzle_cnt{1}; + uint32_t en_shuffle_k{0}; + uint32_t swizzlDirect{0}; + uint64_t L1_PINGPONG_BUFFER_LEN{0}; + uint32_t L0AB_PINGPONG_BUFFER_LEN{0}; +}; + +template +__aicore__ __force_inline__ void PpMatmulW8a8Aiv::Process() +{ + uint32_t m_idx = 0; + uint32_t n_idx = 0; + SET_FLAG(V, MTE2, EVENT_ID0); + SET_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + for (uint32_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) { + GetBlockIdx(loop_idx, m_idx, n_idx); + uint64_t batch_idx = loop_idx / n_loop / m_loop; + uint64_t offsetC = batch_idx * m * n + m_idx * m0 * n + n_idx * n0; + uint32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + uint32_t n_actual = (n_idx == (n_loop - 1)) ? 
(n - n_idx * n0) : n0; + uint32_t m_round = RoundUp(m_actual); + uint32_t n_round = RoundUp(n_actual); + uint32_t n_round_16 = RoundUp(n_actual); + uint32_t m_actual_per_vec = m_actual / AscendC::GetTaskRation(); + uint32_t m_offset = m + m_idx * m0; + if (GetSubBlockidx() != 0) { + offsetC += m_actual_per_vec * n; + m_offset += m_actual_per_vec; + m_actual_per_vec = m_actual - m_actual_per_vec; + } + + if constexpr (!withSyncAll) { + if (m_actual_per_vec == 0) { + WaitFlagDev(MMAIC); + if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) { + FftsCrossCoreSync(MMAIV); + } + continue; + } + } + + uint64_t offsetScale = batch_idx * n + n_idx * n0; + bool aligned_s32 = ((n & 0b111) == 0); // 32B aligned + bool aligned_f16 = ((n & 0b1111) == 0); // 32B aligned + WAIT_FLAG(V, MTE2, EVENT_ID0); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + if (aligned_s32) { + gm_to_ub(ubPerTensorScale_.ReinterpretCast(), + gmPerTensorBias_[offsetScale], + 0, // sid + 1, // nBurst + n_round * sizeof(BiasDtype) / BLOCK_SIZE_32, // lenBurst + 0, // srcStride + 0); // dstStride + } else { + gm_to_ub_align(ubPerTensorScale_.ReinterpretCast(), + gmPerTensorBias_[offsetScale], + 0, // sid + 1, // nBurst + n_actual * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0); // dstGap + } + } else { + if (aligned_s32) { + gm_to_ub(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_round * 4 / BLOCK_SIZE_32, // lenBurst + 0, // srcStride + 0); // dstStride + } else { + gm_to_ub_align(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_actual * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0); // dstGap + } + } + + if constexpr (!withSyncAll) { + WaitFlagDev(MMAIC); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if (aligned_s32) { + gm_to_ub(ubInput_, gmInput_[offsetC], + 0, // sid + m_actual_per_vec, // nBurst + n_round / 8, // lenBurst + (n - n_round) / 8, // srcStride + 0 // dstStride + ); + } else { + gm_to_ub_align(ubInput_, gmInput_[offsetC], + 0, // sid + m_actual_per_vec, // nBurst + n_actual * sizeof(int32_t), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (n - n_actual) * sizeof(int32_t), // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + + WAIT_FLAG(MTE3, V, EVENT_ID0); + uint32_t nRepeatCnt = CeilDiv(n_actual); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::SetMaskCount(); + AscendC::SetVectorMask(n_round); + for (uint32_t i = 0; i < m_actual_per_vec; ++i) { + // add_v(ubInput_[i * n_round], + // ubInput_[i * n_round], + // ubPerTensorScale_.ReinterpretCast(), + // (uint8_t)(nRepeatCnt), // repeat + // (uint8_t)1, // dstBlockStride + // (uint8_t)1, // src0BlockStride + // (uint8_t)1, // src1BlockStride + // (uint8_t)8, // dstRepeatStride + // (uint8_t)8, // src0RepeatStride + // (uint8_t)8 // src1RepeatStride + // ); + AscendC::Add(ubInput_[i * n_round], ubInput_[i * n_round], + ubPerTensorScale_.ReinterpretCast(), + AscendC::MASK_PLACEHOLDER, 1, + AscendC::BinaryRepeatParams((uint8_t)1, (uint8_t)1, (uint8_t)1, + (uint8_t)8, (uint8_t)8, (uint8_t)8)); + } + AscendC::ResetMask(); + SetMasknorm(); + + SET_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID0); + if (aligned_s32) { + gm_to_ub(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_round * sizeof(ScaleDtype) / BLOCK_SIZE_32, // lenBurst + 0, // srcStride + 0 // 
dstStride + ); + } else { + gm_to_ub_align(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_actual * sizeof(ScaleDtype), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + } + + // CASTF32 * f32 tf16 + constexpr uint32_t maxRepeat = 255; + constexpr uint32_t perRepeatNum = maxRepeat * 64; + uint32_t loopCnt = (m_actual_per_vec * n_actual + perRepeatNum - 1) / perRepeatNum; + for (uint32_t i = 0; i < loopCnt; i++) { + conv_v(ubInput_.ReinterpretCast()[perRepeatNum * i], + ubInput_[perRepeatNum * i], + (uint8_t)maxRepeat, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)8, // dstRepeatStride + (uint16_t)8 // srcRepeatStride + ); + } + AscendC::PipeBarrier(); + + for (uint32_t i = 0; i < m_actual_per_vec; ++i) { + mul_v(ubTempFp32_[i * n_round], + ubInput_.ReinterpretCast()[i * n_round], + ubPerTensorScale_.ReinterpretCast(), + (uint8_t)(nRepeatCnt), // repeat + (uint8_t)1, // dstBlockStride + (uint8_t)1, // src0BlockStride + (uint8_t)1, // src1BlockStride + (uint8_t)8, // dstRepeatStride + (uint8_t)8, // src0RepeatStride + (uint8_t)8 // src1RepeatStride + ); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + AscendC::PipeBarrier(); + float perTokenDescale = gmPerTokenScale_.GetValue(m_offset + i); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(ubTempFp32_[i * n_round], ubTempFp32_[i * n_round], perTokenDescale, n_round); + } + AscendC::PipeBarrier(); + } + SET_FLAG(V, MTE2, EVENT_ID0); + AscendC::PipeBarrier(); + if (n_actual % 16 > 8) { + for (uint32_t i = 0; i < loopCnt; i++) { + if constexpr (std::is_same_v) { + convr_v(ubOutput_[perRepeatNum * i], + ubTempFp32_[perRepeatNum * i], + (uint8_t)maxRepeat, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } else { + conv_v(ubOutput_[perRepeatNum * i], + ubTempFp32_[perRepeatNum * i], + (uint8_t)maxRepeat, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } + } + } else { + for (uint32_t i = 0; i < m_actual_per_vec; i++) { + if constexpr (std::is_same_v) { + convr_v(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i], + (uint8_t)nRepeatCnt, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } else { + conv_v(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i], + (uint8_t)nRepeatCnt, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } + } + } + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + if (aligned_f16) { + ub_to_gm(gmOutput_[offsetC], ubOutput_, 0, + m_actual_per_vec, // nBurst + n_round / 16, // lenBurst + 0, // srcStride + (n - n_round) / 16 // dstStride + ); + } else { + ub_to_gm_align(gmOutput_[offsetC], ubOutput_, 0, + m_actual_per_vec, // nBurst + n_actual * sizeof(OutDtype), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + (n - n_actual) * sizeof(OutDtype) // dstGap + ); + } + SET_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + if constexpr (!withSyncAll) { + if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) { + FftsCrossCoreSync(MMAIV); + } + } + } + 
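+ // Per tile, the loop above dequantizes the int32 accumulators roughly as
+ //   PER_TENSOR_ASYMM_QUANT: out = float(acc + bias[n]) * descale[n]
+ //   PER_TOKEN_SYMM_QUANT  : out = float(acc) * descale[n] * per_token_descale[m]
+ // before casting down to OutDtype. The waits below drain the flags primed at function entry.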
WAIT_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); +} +#endif + +template +class MLAOperation +{ + static constexpr bool mm1WithSyncAll = (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT); + static constexpr uint64_t splitGapC = CACHE_MODE == CACHE_MODE_KVCACHE ? CONST_64 : CONST_0; + using Q_OUT_DTYPE = typename std::conditional_t; + using K_NOPE_DTYPE = typename std::conditional_t; + +public: + __aicore__ inline MLAOperation(const MlaTilingData &mlaParams_, GM_ADDR tilingGm) + { + blockIdx = AscendC::GetBlockIdx(); +#ifdef __DAV_C220_VEC__ + sub_block_idx = static_cast(GetSubBlockidx()); +#endif + vectorBlockIdx = (blockIdx / 2) * 2 + sub_block_idx; + this->n = mlaParams_.n; + this->num_core_ = mlaParams_.rmsNumCore1; + this->num_col_1 = mlaParams_.rmsNumCol1; + this->num_col_2 = mlaParams_.rmsNumCol2; + this->num_row = mlaParams_.n; + this->epsilon_ = 1e-6; + this->mlaParams = mlaParams_; + } + + __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, + GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, + GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, + GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, + GM_ADDR slotMappingGm, GM_ADDR wuqGm, GM_ADDR bias2Gm, GM_ADDR wukGm, + GM_ADDR descale1Gm, GM_ADDR descale2Gm, GM_ADDR gmCtkvScale, GM_ADDR gmQnopeScale, + GM_ADDR qGm, GM_ADDR keycacheOutGm, GM_ADDR qGm2, GM_ADDR keycacheOutGm2, GM_ADDR s1Gm, + GM_ADDR s2Gm, GM_ADDR s3Gm, GM_ADDR s4Gm, GM_ADDR s5Gm) + { + quantScale3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmCtkvScale)); + gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma3Gm)); + sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin1Gm)); + cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos1Gm)); + keycacheGmTensor1.SetGlobalBuffer(reinterpret_cast<__gm__ K_NOPE_DTYPE *>(keycacheOutGm)); + keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(keycacheOutGm2)); + slotMappingGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slotMappingGm)); + descale1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale1Gm)); + s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s2Gm)); + s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(s3Gm)); + s5GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(s5Gm)); + +#ifdef __DAV_C220_CUBE__ + mm_w8a8_aic_1.Init(s1Gm, wdqkvGm, s2Gm, mlaParams.mm1, 0); + mm_w8a8_aic_1.PreloadWeight(); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + mm_w8a8_aic_2.Init(s1Gm, wuqGm, s2Gm, mlaParams.mm2, 1); + } else { + // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + mm_w8a8_aic_2.Init(s1Gm, wuqGm, s3Gm, mlaParams.mm2, 1); + } + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + mm_ein_sum.Init(s4Gm, wukGm, s1Gm, mlaParams); + } else { + mm_ein_sum.Init(s4Gm, wukGm, qGm, mlaParams); + } +#endif + + hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm)); + gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma1Gm)); + quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm)); + quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); + wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm)); + gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ 
InDtype *>(gamma2Gm)); + quantScale2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale2Gm)); + quantOffset2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset2Gm)); + sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin2Gm)); + cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos2Gm)); + wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wuqGm)); + wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(wukGm)); + descale2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale2Gm)); + s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(s1Gm)); + s4GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s4Gm)); + qGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ Q_OUT_DTYPE *>(qGm)); + qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2)); + bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm)); + bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); + beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta1Gm)); + beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm)); + +#ifdef __DAV_C220_VEC__ + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + mm_w8a8_aiv_1.Init(s2Gm, s3Gm, descale1Gm, bias1Gm, s5Gm, mlaParams.mm1); + mm_w8a8_aiv_2.Init(s2Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2); + } else { + // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + mm_w8a8_aiv_2.Init(s3Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2); + } + row_work = (num_row + num_core_ - 1) / num_core_; + row_work_ = 0; + uint32_t need_core = (num_row + row_work - 1) / row_work; + if (vectorBlockIdx < need_core - 1) { + row_work_ = row_work; + } else if (vectorBlockIdx == need_core - 1) { + row_work_ = num_row - (need_core - 1) * row_work; + } else { + row_work_ = 0; + } + this->splitN = mlaParams.perTaskNum; + Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, s5Gm + row_work * vectorBlockIdx * sizeof(float), + descale1Gm, hiddenStateGm, s1Gm, 0, num_col_1, + vectorBlockIdx * static_cast(row_work) * num_col_1, + vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE, + num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + } else { + // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE, + num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + } + ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); + einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); +#endif + } + + __aicore__ inline void ProcessCube(); + + __aicore__ inline void ProcessVector(); + +private: + constexpr static uint32_t C0_SIZE = 16; + constexpr static uint32_t I8_C0_SIZE = 32; + + template + __aicore__ inline void RmsNormAndRopeConvergence1( + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const 
AscendC::LocalTensor &sinTensor, const AscendC::LocalTensor &cosTensor, + const AscendC::LocalTensor &slotMappingTensor, const uint32_t sN, + const AscendC::LocalTensor &rmsNormTensor, const AscendC::LocalTensor &gammaFp32, + const AscendC::LocalTensor &ropeKTensor, const AscendC::LocalTensor &ropeKRevertTensor, + const AscendC::LocalTensor &calTensor, const AscendC::LocalTensor &outTmpTensor, + AscendC::LocalTensor &tmpfp16, AscendC::LocalTensor &int8OutTensor, float quantScale3) + { + int64_t slotMapGmOffset = vectorBlockIdx * row_work; + AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], + AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + mmTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE]; + deScaleTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE * 2]; + AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); + } + SET_FLAG(MTE2, V, EVENT_ID2); + WAIT_FLAG(MTE2, V, EVENT_ID2); + SET_FLAG(MTE2, S, EVENT_ID2); + WAIT_FLAG(MTE2, S, EVENT_ID2); + for (uint64_t loop = 0; loop < sN; ++loop) { + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); + if (slotValue == -1) { + continue; + } + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::DataCopy(srcTensor, s3GmTensor[offset], + AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0)); + } else { + // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); + } + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + SET_FLAG(MTE2, V, EVENT_ID0); + // ND + uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + // NZ + uint32_t outer_idx = slotValue / 128; + uint32_t inner_idx = slotValue % 128; + + SET_FLAG(S, MTE3, EVENT_ID0); + /* RmsNorm start */ + WAIT_FLAG(MTE2, V, EVENT_ID0); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + /* DeQuant */ + AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, + SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, + SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, + SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + } + Cast(rmsNormTensor, srcTensor, 
AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], + SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(V, S, EVENT_ID1); + WAIT_FLAG(V, S, EVENT_ID1); + float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + SET_FLAG(S, V, EVENT_ID1); + WAIT_FLAG(S, V, EVENT_ID1); + AscendC::PipeBarrier(); + Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + // quant + Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + } else { + AscendC::PipeBarrier(); + if (std::is_same::value) { + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + } else { + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + } + } + /* RmsNorm end */ + /* Rope K start */ + uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; + Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + revertOffset); + Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + revertOffset); + Duplicate(calTensor, static_cast(-1), revertOffset); + Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); + AscendC::PipeBarrier(); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + if (std::is_same::value) { + Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, + SPLIT_RMSNRORM_SIZE_TWO); + } else { + Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + } + AscendC::PipeBarrier(); + /* Rope K end */ + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) { + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + } else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE; + uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; + // 
nope:int8 nz + AscendC::DataCopyExtParams outExt; + outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE; + outExt.blockLen = I8_C0_SIZE * sizeof(int8_t); + outExt.srcStride = 0; + outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t); + DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt); + // rope:T1 nz + outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + } else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) { + uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE; + uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; + // nope:T1 nz + AscendC::DataCopyExtParams outExt; + outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt); + // rope:T1 nz + outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + } else { + // keycache1 + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + // keycache2 + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], + SPLIT_RMSNRORM_SIZE_TWO); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + } + +private: + uint32_t n; + uint32_t splitN; + uint32_t rotaryCoeff; + uint32_t blockIdx; + uint32_t sub_block_idx; + uint32_t vectorBlockIdx; + uint32_t blockOffset; + uint32_t perTaskNum; + uint32_t resTaskNum; + MlaTilingData mlaParams; + + uint32_t num_core_; + uint32_t num_col_1; + uint32_t num_col_2; + float epsilon_; + uint32_t num_row; + uint32_t quantMin_; + uint32_t row_work; + uint32_t row_work_; + + AsdopsBuffer buf; + AscendC::LocalTensor mmTensor; + AscendC::LocalTensor deScaleTensor; + + AscendC::GlobalTensor hiddenStateGmTensor; + + AscendC::GlobalTensor gamma1GmTensor; + AscendC::GlobalTensor quantScale1GmTensor; + AscendC::GlobalTensor quantOffset1GmTensor; + + AscendC::GlobalTensor wdqkvGmTensor; + AscendC::GlobalTensor gamma2GmTensor; + AscendC::GlobalTensor quantScale2GmTensor; + AscendC::GlobalTensor quantScale3GmTensor; + AscendC::GlobalTensor quantOffset2GmTensor; + AscendC::GlobalTensor gamma3GmTensor; + AscendC::GlobalTensor sin1GmTensor; + AscendC::GlobalTensor cos1GmTensor; + AscendC::GlobalTensor sin2GmTensor; + AscendC::GlobalTensor cos2GmTensor; + AscendC::GlobalTensor keycacheGmTensor1; + AscendC::GlobalTensor keycacheGmTensor2; + AscendC::GlobalTensor slotMappingGmTensor; + AscendC::GlobalTensor wuqGmTensor; + AscendC::GlobalTensor wukGmTensor; + + // cachemode2-->int8; else bf16 + AscendC::GlobalTensor qGmTensor; + AscendC::GlobalTensor qGmTensor2; + AscendC::GlobalTensor s1GmTensor; + AscendC::GlobalTensor s2GmTensor; + AscendC::GlobalTensor s3GmTensor; + AscendC::GlobalTensor s4GmTensor; + AscendC::GlobalTensor s5GmTensor; + AscendC::GlobalTensor descale1gmTensor; + AscendC::GlobalTensor descale2gmTensor; + AscendC::GlobalTensor beta1GmTensor; + AscendC::GlobalTensor beta2GmTensor; + + AscendC::GlobalTensor bias1gmTensor; + 
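+ // Per-tensor int32 bias for MM2 (the wuq projection); bias1gmTensor above plays the same role for
+ // MM1 (the wdqkv projection). Both appear to be consumed only in PER_TENSOR_ASYMM_QUANT mode.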
AscendC::GlobalTensor bias2gmTensor; + +#ifdef __DAV_C220_CUBE__ + PpMatmulW8a8Aic mm_w8a8_aic_1; + PpMatmulW8a8Aic mm_w8a8_aic_2; + PpMatmulEinSum mm_ein_sum; +#endif + +#ifdef __DAV_C220_VEC__ + PpMatmulW8a8Aiv mm_w8a8_aiv_1; + PpMatmulW8a8Aiv mm_w8a8_aiv_2; + + Quant Quant1; + + RmsNormQuant rmsNormQuant2; + RopeFp16 ropeFp16; + EinSumQuant einSumQuant; +#endif +}; + +template +__aicore__ inline void +MLAOperation::ProcessCube() +{ +#ifdef __DAV_C220_CUBE__ + mm_w8a8_aic_1.Process(); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + FftsCrossCoreSync(MMAIC); + WaitFlagDev(MMAIC); + FftsCrossCoreSync(MMAIV); + } + + mm_w8a8_aic_2.PreloadWeight(); + mm_w8a8_aic_2.Process(); + mm_ein_sum.Process(); + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + FftsCrossCoreSync(EINSUMOUT); + WaitFlagDev(EINSUMOUT); + FftsCrossCoreSync(EINSUMQUANT); + } +#endif +} + +template +__aicore__ inline void +MLAOperation::ProcessVector() +{ +#ifdef __DAV_C220_VEC__ + if (row_work_ != 0) { + uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor scale_tensor = + buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); + AscendC::LocalTensor offset_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32); + AscendC::LocalTensor res1_tensor = + buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64); + AscendC::LocalTensor res3_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor output_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64); + Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); + } + FftsCrossCoreSync(QUANT1); + WaitFlagDev(QUANT1); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM1_START); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + mm_w8a8_aiv_1.Process(); + FftsCrossCoreSync(RMSNORMQUANT2); + WaitFlagDev(RMSNORMQUANT2); + } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + WaitFlagDev(MMAIV); + } + if (row_work_ != 0) { + uint32_t num_col_align_int8 = (num_col_2 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor beta_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + AscendC::LocalTensor scale_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2); + AscendC::LocalTensor offset_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32); + AscendC::LocalTensor res1_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + AscendC::LocalTensor res3_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + 
num_col_align_f32 * 4); + AscendC::LocalTensor output_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32); + rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, + res1_tensor, res3_tensor); + } + FftsCrossCoreSync(MM2); + WaitFlagDev(MM2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM2_START); + + if (row_work_ != 0) { + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor sin_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + AscendC::LocalTensor cos_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); + AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + int32_t rms3_ub_offset = + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); + + int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + + 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 + + MM1_OUT_SIZE * 4 * 2 + 32; + AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); + + AscendC::LocalTensor tmpfp16; + AscendC::LocalTensor int8OutTensor; + float scale3 = 0; + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + // quantScale3 + AscendC::LocalTensor quantScaleTensor = + buf.GetBuffer(rms3_ub_offset); + AscendC::LocalTensor floatQuantScaleTensor = + buf.GetBuffer(rms3_ub_offset + 32); + // int8out + tmpfp16 = buf.GetBuffer(rms3_ub_offset + + SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); + int8OutTensor = buf.GetBuffer(out_ub_offset); + AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(floatQuantScaleTensor, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + scale3 = 1 / (float)(floatQuantScaleTensor.GetValue(0)); + } + + RmsNormAndRopeConvergence1( + input_tensor, // n * 576 + gamma_tensor, // gamma + sin_tensor, // sin + cos_tensor, // cons + slotMapping_tensor, // slotMapping + row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO + + SPLIT_RMSNRORM_SIZE_TWO], + temp_tensor, tmpfp16, int8OutTensor, scale3); + } + mm_w8a8_aiv_2.Process(); + FftsCrossCoreSync(MM2OUT); + WaitFlagDev(MM2OUT); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM3_START); + ropeFp16.Process(); + + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + WaitFlagDev(EINSUMQUANT); + einSumQuant.Process(); + } +#endif +} + +} // namespace MLAPO_BF16 diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp new file mode 100644 index 0000000000..097fbc2ad1 --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp @@ -0,0 +1,2508 @@ +// Adapted from +// 
https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// + +#include "kernel/common.h" +#include "kernel/iterator.h" +#include "kernel/mem.h" +#include "kernel/mma.h" +#include "kernel/utils.h" +#include "kernel/simd.h" +#include "kernel/kernel_utils.h" + +#include "lib/matmul_intf.h" + +#include "mla_preprocess.h" +#include "../op_host/tiling/mla_preprocess_tiling.h" + +namespace MLAPO_FP16 { + +template +class RopeFp16 +{ +public: + __aicore__ inline RopeFp16() : blockIdx_(AscendC::GetBlockIdx()) {} + + __aicore__ inline void RopeInit(AscendC::GlobalTensor &qGm, AscendC::GlobalTensor &cosGm, + AscendC::GlobalTensor &sinGm, + AscendC::GlobalTensor &outRopeConcatGm, + AscendC::GlobalTensor &outRopeConcatGm2, + const MlaTilingData &ropeConcatParams) + { + this->qGm_ = qGm; + this->cosGm_ = cosGm; + this->sinGm_ = sinGm; + this->outRopeConcatGm_ = outRopeConcatGm; + this->outRopeConcatGm2_ = outRopeConcatGm2; + + headDim = ropeConcatParams.headDim; + headNumQ = ropeConcatParams.headNumQ; + rotaryCoeff = ropeConcatParams.rotaryCoeff; + ntokens = ropeConcatParams.ntokens; + realCore = ropeConcatParams.realCore; + nlCoreRun = ropeConcatParams.nlCoreRun; + lCoreRun = ropeConcatParams.lCoreRun; + maxNPerLoopForUb = ropeConcatParams.maxNPerLoopForUb; + preCoreLoopTime = ropeConcatParams.preCoreLoopTime; + preCoreLoopNLast = ropeConcatParams.preCoreLoopNLast; + lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime; + lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast; + concatSize = ropeConcatParams.concatSize; + blockIdx_ = (blockIdx_ / 2) * 2 + static_cast(GetSubBlockidx()); + loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime; + lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast; + this->repeatSize_ = 64; // 64 = 256B / sizeof(fp32) + this->rotateStride_ = this->headDim / this->rotaryCoeff; + headBlockLen = static_cast(this->headDim / ELE_NUM_FP16); + headBlockLenFP32 = static_cast(this->headDim / ELE_NUM_FP32); + rotaryLen = static_cast(this->rotateStride_ / ELE_NUM_FP32); + concatBlockLen = static_cast(this->concatSize / ELE_NUM_FP16); + outLineOffset = this->headDim + this->concatSize; + uint32_t dataNum = this->headDim * this->maxNPerLoopForUb; + dataSizeFp16 = dataNum * sizeof(QkDtype); + dataSizeFp32 = dataNum * sizeof(float); + uint32_t concatDataSize = this->concatSize * sizeof(QkDtype) * this->maxNPerLoopForUb; + } + + __aicore__ inline void Process() + { + if (blockIdx_ >= realCore) return; + uint64_t startCoreLineIndex = this->blockIdx_ * this->nlCoreRun; + // neg sign pattern of shape [maxNPerLoopForUb, head_dim] + AscendC::LocalTensor negLocal = + buf.GetBuffer(dataSizeFp32 * 4 + dataSizeFp16 * 3); + ExpandNeg(negLocal, this->maxNPerLoopForUb); + + SET_FLAG(MTE3, MTE2, EVENT_ID1); + for (uint32_t zz = 0; zz < this->loopTime; ++zz) { + uint16_t loopN = (zz == this->loopTime - 1) ?
this->lastLoopN : this->maxNPerLoopForUb; + uint64_t startHead = startCoreLineIndex + zz * this->maxNPerLoopForUb; + uint64_t endHead = startHead + loopN; + + // move in Q + AscendC::LocalTensor inputQ = buf.GetBuffer(0); + AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); + AscendC::LocalTensor reverseQ = + buf.GetBuffer(dataSizeFp32 + dataSizeFp16); + uint64_t qOffset = startHead * 192 + 128; + CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); + + // move in cos/sin + AscendC::LocalTensor inputCos = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16); + AscendC::LocalTensor inputSin = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 2); + uint64_t startSinCosHeadIndex = startHead; + uint64_t headRemain = startHead % this->headNumQ; + uint64_t localStartAddr = 0; + if (headRemain != 0) { + uint64_t preProcessHeadNum = this->headNumQ - headRemain; + uint64_t needToProcesHead = preProcessHeadNum > loopN ? loopN : preProcessHeadNum; + CopyCosSin(inputCos, inputSin, localStartAddr, (startSinCosHeadIndex / this->headNumQ) * this->headDim, + needToProcesHead); + startSinCosHeadIndex += needToProcesHead; + localStartAddr += needToProcesHead * this->headDim; + } + // Iterate through the remaining data. + if (startSinCosHeadIndex < endHead) { + uint64_t startSinCosIndex = startSinCosHeadIndex / this->headNumQ; + uint64_t endSinCosIndex = (endHead + this->headNumQ - 1) / this->headNumQ; + for (uint32_t index = startSinCosIndex; index < endSinCosIndex; ++index) { + // Mantissa + uint32_t repeatNum = + index == endSinCosIndex - 1 ? endHead - index * this->headNumQ : this->headNumQ; + CopyCosSin(inputCos, inputSin, localStartAddr, index * this->headDim, repeatNum); + localStartAddr += this->headDim * this->headNumQ; + } + } + AscendC::LocalTensor inputCosCastFP32 = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 3); + AscendC::LocalTensor inputSinCastFP32 = + buf.GetBuffer(dataSizeFp32 * 3 + dataSizeFp16 * 3); + AscendC::Cast(inputCosCastFP32, inputCos, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::Cast(inputSinCastFP32, inputSin, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + + // rope result + uint32_t repeatTime = this->headDim * loopN; + AscendC::Mul(inputQCastFP32, inputCosCastFP32, inputQCastFP32, repeatTime); + + AscendC::Mul(reverseQ, negLocal, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Mul(reverseQ, inputSinCastFP32, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Add(inputQCastFP32, reverseQ, inputQCastFP32, repeatTime); + AscendC::PipeBarrier(); + + // move out rope result + // cast fp16/bf16 + AscendC::Cast(inputQ, inputQCastFP32, AscendC::RoundMode::CAST_RINT, loopN * this->headDim); + AscendC::PipeBarrier(); + uint64_t outQOffset = startHead * outLineOffset + this->concatSize; + uint64_t outQOffset2 = startHead * this->headDim; + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + if constexpr (CacheMode == CACHE_MODE_KVCACHE) { + AscendC::DataCopy(this->outRopeConcatGm_[outQOffset], inputQ, {loopN, headBlockLen, 0, concatBlockLen}); + } else { + AscendC::DataCopy(this->outRopeConcatGm2_[outQOffset2], inputQ, loopN * this->headDim); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + // tensor -1 -1 -1 1 1 1 + template + __aicore__ inline void ExpandNeg(const AscendC::LocalTensor &tempBuf, uint32_t headNumTemp) + { + for (uint32_t i = 0; i < this->rotateStride_; ++i) { + tempBuf.SetValue(i, (BUF_TYPE)-1); + 
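+ // The next store writes +1 into the second half of each rotation group, pairing with the -1 above; Process() multiplies this sign pattern into the swapped halves of Q so that q * cos + rotate_half(q) * sin realizes the rotary embedding.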
tempBuf.SetValue(i + this->rotateStride_, (BUF_TYPE)1); + } + SET_FLAG(S, V, EVENT_ID1); + WAIT_FLAG(S, V, EVENT_ID1); + AscendC::Copy(tempBuf[this->headDim], tempBuf, this->headDim, headNumTemp - 1, {1, 1, headBlockLenFP32, 0}); + } + + template + __aicore__ inline void CopyQGenReverseQ(const AscendC::LocalTensor &tempBufQ, + const AscendC::LocalTensor &tempBufQCast, + const AscendC::LocalTensor &tempBufRverseQ, uint64_t qOffset, + uint16_t loopN) + { + SET_FLAG(S, MTE2, EVENT_ID1); + WAIT_FLAG(S, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + // move in Q + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + // cast fp32 + AscendC::Cast(tempBufQCast, tempBufQ, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + // move in reverseQ + AscendC::DataCopy(tempBufRverseQ, tempBufQCast[this->rotateStride_], {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::DataCopy(tempBufRverseQ[this->rotateStride_], tempBufQCast, {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::PipeBarrier(); + } + + template + __aicore__ inline void CopyCosSin(const AscendC::LocalTensor &tempBufCos, + const AscendC::LocalTensor &tempBufSin, uint64_t localStartAddr, + uint64_t gmStartAddr, uint64_t repeatNum) + { + SET_FLAG(S, MTE2, EVENT_ID1); + WAIT_FLAG(S, MTE2, EVENT_ID1); + AscendC::DataCopy(tempBufCos[localStartAddr], this->cosGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + AscendC::DataCopy(tempBufSin[localStartAddr], this->sinGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + AscendC::Copy(tempBufCos[localStartAddr + this->headDim], tempBufCos[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::Copy(tempBufSin[localStartAddr + this->headDim], tempBufSin[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::PipeBarrier(); + } + +private: + AsdopsBuffer buf; + + AscendC::GlobalTensor qGm_; + AscendC::GlobalTensor cosGm_; + AscendC::GlobalTensor sinGm_; + AscendC::GlobalTensor outRopeConcatGm_; + AscendC::GlobalTensor outRopeConcatGm2_; + + uint32_t repeatSize_{0}; + uint32_t rotateStride_{0}; // this->headDim / rope_conf + uint32_t headDim; + uint32_t headNumQ; + uint32_t rotaryCoeff; + uint32_t ntokens; + uint32_t realCore; + uint32_t nlCoreRun; + uint32_t lCoreRun; + uint32_t maxNPerLoopForUb; + uint32_t preCoreLoopTime; + uint32_t preCoreLoopNLast; + uint32_t lastCoreLoopTime; + uint32_t lastCoreLoopNLast; + uint32_t concatSize; + uint32_t blockIdx_; + uint32_t loopTime{0}; // The number of current data rounds + uint32_t lastLoopN{0}; // The number of lines currently processed by tails kernel + + uint32_t dataSizeFp32; + uint32_t dataSizeFp16; + uint16_t headBlockLen{0}; + uint16_t headBlockLenFP32{0}; + uint16_t rotaryLen{0}; + uint16_t concatBlockLen{0}; + uint64_t outLineOffset{0}; +}; + +__aicore__ inline void ReduceSumCustom(const AscendC::LocalTensor &dst_local, + const AscendC::LocalTensor &src_local, + const AscendC::LocalTensor &work_local, int32_t count) +{ +#ifdef __DAV_C220_VEC__ + uint64_t mask = NUM_PER_REP_FP32; + int32_t repeatTimes = count / NUM_PER_REP_FP32; + int32_t tailCount = count % NUM_PER_REP_FP32; + int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = AscendC::ONE_REPEAT_BYTE_SIZE / AscendC::ONE_BLK_SIZE; + repeatParams.src0BlkStride = 1; + 
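+ // src1/dst repeat strides below are 0, so every Add repeat accumulates into the same NUM_PER_REP_FP32-lane work buffer; the final cadd_v folds those per-lane partial sums into a single value.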
repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 0; + repeatParams.dstBlkStride = 1; + Duplicate(work_local, ZERO, NUM_PER_REP_FP32); + AscendC::PipeBarrier(); + if (likely(repeatTimes > 0)) { + Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams); + AscendC::PipeBarrier(); + } + if (unlikely(tailCount != 0)) { + Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams); + AscendC::PipeBarrier(); + } + AscendC::AscendCUtils::SetMask(NUM_PER_REP_FP32); + cadd_v(dst_local, // dst + work_local, // src + 1, // repeat + 0, // dstRepeatStride + 1, // srcBlockStride + 0); // srcRepeatStride + AscendC::PipeBarrier(); +#endif +} + +template +class Quant +{ +public: + __aicore__ inline Quant() {} + + __aicore__ inline void Init(AscendC::GlobalTensor gammaGmTensor, AscendC::GlobalTensor betaGmTensor, + AscendC::GlobalTensor quantScaleGmTensor, + AscendC::GlobalTensor quantOffsetGmTensor, + AscendC::GlobalTensor inputGmTensor, AscendC::GlobalTensor outputGmTensor, + uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset, + uint64_t gm_out_offset, uint32_t row_work_, const MlaTilingData &mlaParams_) + { + this->quantScaleGmTensor = quantScaleGmTensor; + this->quantOffsetGmTensor = quantOffsetGmTensor; + this->inputGmTensor = inputGmTensor; + this->outputGmTensor = outputGmTensor; + num_col_ = num_col; + quantMin_ = -128; + uint32_t num_row = mlaParams_.n; + this->row_work = row_work; + this->row_work_ = row_work_; + gm_offset_ = gm_offset; + gm_out_offset_ = gm_out_offset; + num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + input_stride_ = stride; + + num_col_align_withStride_int8 = + (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_withStride_fp16 = + (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_withStride_fp32 = + (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + } + + __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &betaTensor, + const AscendC::LocalTensor &quantScaleTensor, + const AscendC::LocalTensor &quantOffsetTensor, + const AscendC::LocalTensor &res1Tensor, + const AscendC::LocalTensor &res3Tensor) + { + this->dstTensor = dstTensor; + this->srcTensor = srcTensor; + this->fp32_xy = res1Tensor; + this->buf = res3Tensor; + + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_], + AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID0); + + SET_FLAG(MTE2, V, EVENT_ID1); + AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, + AscendC::DataCopyParams(1, 1, 0, 0)); // 7168 * 2 + 7168 * 2 + 32 + AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, + AscendC::DataCopyParams(1, 1, 0, 0)); // 7168 * 2 + 7168 * 2 + 64 + SET_FLAG(MTE2, S, EVENT_ID0); + + uint64_t pid = 0; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + while (pid < row_work_) { + uint64_t offset = pid * num_col_; // + offset + uint64_t outOffset = pid * (num_col_ - input_stride_); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if (pid > 0) { + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], + 
AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); // 7168 * 2 + SET_FLAG(MTE2, V, EVENT_ID0); + } + WAIT_FLAG(MTE2, V, EVENT_ID0); + + // modify input + Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + + if (pid == 0) { + WAIT_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, S, EVENT_ID0); + input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + } + + Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpfp16 = + buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); + AscendC::PipeBarrier(); + CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); + SET_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + ++pid; + AscendC::PipeBarrier(); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + } + +private: + AscendC::LocalTensor dstTensor; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor fp32_xy; + AscendC::LocalTensor buf; + + AscendC::GlobalTensor quantScaleGmTensor; + AscendC::GlobalTensor quantOffsetGmTensor; + AscendC::GlobalTensor inputGmTensor; + AscendC::GlobalTensor outputGmTensor; + + uint32_t num_col_{0}; + uint32_t row_work{0}; + uint32_t row_work_{0}; + uint32_t row_step_{0}; + uint32_t row_tail_{0}; + uint64_t gm_offset_{0}; + uint64_t gm_out_offset_{0}; + float avg_factor_{1.0}; // 1/num_col_ + float input_scale_{1.0}; + float input_offset_{0}; + int32_t input_stride_{0}; + float epsilon_{1e-12f}; + uint32_t num_col_align_int8{0}; + uint32_t num_col_align_f16{0}; + uint32_t num_col_align_f32{0}; + uint32_t num_col_align_f32_long{0}; + uint32_t num_col_align_withStride_int8{0}; + uint32_t num_col_align_withStride_fp16{0}; + uint32_t num_col_align_withStride_fp32{0}; + uint32_t num_col_temp; + half quantMin_{-128}; + uint32_t num_slice_{0}; + uint32_t tail_size_{0}; + uint32_t tail_copy_{0}; +}; + +template +class RmsNormQuant +{ +public: + __aicore__ inline RmsNormQuant() {} + + __aicore__ inline void Init(AscendC::GlobalTensor gammaGmTensor, AscendC::GlobalTensor betaGmTensor, + AscendC::GlobalTensor quantScaleGmTensor, + AscendC::GlobalTensor quantOffsetGmTensor, + AscendC::GlobalTensor inputGmTensor, AscendC::GlobalTensor outputGmTensor, + uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset, + uint64_t gm_out_offset, uint32_t row_work_, const MlaTilingData &mlaParams_) + { + this->gammaGmTensor = gammaGmTensor; + this->betaGmTensor = betaGmTensor; + this->quantScaleGmTensor = quantScaleGmTensor; + this->quantOffsetGmTensor = quantOffsetGmTensor; + this->inputGmTensor = inputGmTensor; + this->outputGmTensor = 
outputGmTensor; + num_col_ = num_col; + avg_factor_ = avg_factor; + epsilon_ = 1e-6; + quantMin_ = -128; + uint32_t num_row = mlaParams_.n; + this->row_work = row_work; + this->row_work_ = row_work_; + gm_offset_ = gm_offset; + gm_out_offset_ = gm_out_offset; + num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + input_stride_ = stride; + + num_col_align_withStride_int8 = + (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_withStride_fp16 = + (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_withStride_fp32 = + (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + } + + __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &betaTensor, + const AscendC::LocalTensor &quantScaleTensor, + const AscendC::LocalTensor &quantOffsetTensor, + const AscendC::LocalTensor &res1Tensor, + const AscendC::LocalTensor &res3Tensor) + { + this->dstTensor = dstTensor; + this->srcTensor = srcTensor; + this->gammaTensor = gammaTensor; + this->betaTensor = betaTensor; + this->fp32_xy = res1Tensor; + this->buf = res3Tensor; + AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 + AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 + AscendC::LocalTensor work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 + AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE * num_col_align_withStride_fp32]; // 4 + + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_], + AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID0); + + AscendC::DataCopy( + gammaTensor, gammaGmTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); // 7168 * 2 + 7168 * 2 + AscendC::DataCopy( + betaTensor, betaGmTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); // 7168 * 2 + 7168 * 2 + SET_FLAG(MTE2, V, EVENT_ID1); + AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, + AscendC::DataCopyParams(1, 1, 0, 0)); // 7168 * 2 + 7168 * 2 + 32 + AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, + AscendC::DataCopyParams(1, 1, 0, 0)); // 7168 * 2 + 7168 * 2 + 64 + SET_FLAG(MTE2, S, EVENT_ID0); + + uint64_t pid = 0; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + while (pid < row_work_) { + uint64_t offset = pid * num_col_; // + offset + uint64_t outOffset = pid * (num_col_ - input_stride_); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if (pid > 0) { + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); // 7168 * 2 + SET_FLAG(MTE2, V, EVENT_ID0); + } + WAIT_FLAG(MTE2, V, EVENT_ID0); + + // modify input + Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + Mul(sqx, fp32_xy, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + 
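+ // sqx now holds x * x; the code below forms rms = sqrt(mean(x^2) + eps) via Muls(avg_factor_) and ReduceSumCustom, scales each row by 1/rms and gamma (plus beta when WITH_BETA), then applies the per-tensor quant scale/offset and casts to int8.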
AscendC::PipeBarrier(); + Muls(sqx, sqx, avg_factor_, num_col_ - input_stride_); + AscendC::PipeBarrier(); + ReduceSumCustom(sum, sqx, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + Adds(sum, sum, epsilon_, 1); + AscendC::PipeBarrier(); + Sqrt(sum, sum, 1); + SET_FLAG(V, S, EVENT_ID0); + WAIT_FLAG(V, S, EVENT_ID0); + float factor = 1 / sum.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, factor, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + + if (pid == 0) { + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(buf[OFFSET_GAMMA * num_col_align_withStride_fp32], gammaTensor, AscendC::RoundMode::CAST_NONE, + REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + + WAIT_FLAG(MTE2, S, EVENT_ID0); + input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } + + Mul(fp32_xy, fp32_xy, g, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + if constexpr (WITH_BETA) { // quant beta is fp16 add + AscendC::LocalTensor b = this->betaTensor; + Cast(work, b, AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + Add(fp32_xy, fp32_xy, work, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + } + Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + + AscendC::LocalTensor tmpfp16 = + buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); + AscendC::PipeBarrier(); + CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); + SET_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + ++pid; + AscendC::PipeBarrier(); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + } + +private: +private: + AscendC::LocalTensor dstTensor; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor gammaTensor; + AscendC::LocalTensor betaTensor; + AscendC::LocalTensor fp32_xy; + AscendC::LocalTensor buf; + + AscendC::GlobalTensor gammaGmTensor; + AscendC::GlobalTensor betaGmTensor; + AscendC::GlobalTensor quantScaleGmTensor; + AscendC::GlobalTensor quantOffsetGmTensor; + AscendC::GlobalTensor inputGmTensor; + AscendC::GlobalTensor outputGmTensor; + + uint32_t num_col_{0}; + uint32_t row_work{0}; + uint32_t row_work_{0}; + uint32_t 
row_step_{0}; + uint32_t row_tail_{0}; + uint64_t gm_offset_{0}; + uint64_t gm_out_offset_{0}; + float avg_factor_{1.0}; + float input_scale_{1.0}; + float input_offset_{0}; + int32_t input_stride_{0}; + float epsilon_{1e-12f}; + uint32_t num_col_align_int8{0}; + uint32_t num_col_align_f16{0}; + uint32_t num_col_align_f32{0}; + uint32_t num_col_align_f32_long{0}; + uint32_t num_col_align_withStride_int8{0}; + uint32_t num_col_align_withStride_fp16{0}; + uint32_t num_col_align_withStride_fp32{0}; + uint32_t num_col_temp; + half quantMin_{-128}; + uint32_t num_slice_{0}; + uint32_t tail_size_{0}; + uint32_t tail_copy_{0}; +}; + +__aicore__ __force_inline__ uint64_t Min(const uint64_t a, const uint64_t b) +{ + return a < b ? a : b; +} + +__aicore__ __force_inline__ uint64_t Max(const uint64_t a, const uint64_t b) +{ + return a > b ? a : b; +} + +template +__aicore__ __force_inline__ uint64_t RoundUp(const uint64_t val) +{ + return (val + Base - 1) / Base * Base; +} + +template +__aicore__ __force_inline__ uint64_t CeilDiv(const uint64_t dividend) +{ + return (dividend + Divisor - 1) / Divisor; +} + +template +class EinSumQuant +{ +public: + __aicore__ explicit EinSumQuant() {} + + __aicore__ inline void Init(GM_ADDR einSumOutGm, GM_ADDR scaleGm, GM_ADDR quantOutGm, + const MlaTilingData &tilingData) + { + einSumOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(einSumOutGm)); + scaleGm_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(scaleGm)); + quantOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOutGm)); + + headNum = tilingData.esqHeadNum; + colNum = tilingData.esqColNum; + ubHeadLoop = tilingData.esqUbHeadLoop; + headPerLoop = tilingData.esqHeadPerLoop; + headTail = tilingData.esqHeadTail; + colLoop = tilingData.esqColLoop; + colTail = tilingData.esqColTail; + + currentIdx = (AscendC::GetBlockIdx() / 2) * 2 + GetSubBlockidx(); + if (currentIdx < tilingData.esqFrontCore) { + batchNum = tilingData.esqFrontCoreBatch; + currentCoreStartOffset = currentIdx * tilingData.esqFrontCoreBatch * headNum * colNum; + } else { + batchNum = tilingData.esqTailCoreBatch; + currentCoreStartOffset = (tilingData.esqFrontCore * tilingData.esqFrontCoreBatch + + (currentIdx - tilingData.esqFrontCore) * tilingData.esqTailCoreBatch) * + headNum * colNum; + } + + // calc tensors' data size(bytes) + inputDataSize = headPerLoop * colNum * sizeof(InDtype); + scaleDataSize = headPerLoop * sizeof(ScaleDtype); + scaleBrcbFp16DataSize = headPerLoop * ELE_NUM_FP16 * sizeof(half); + tempQuantFp16DataSize = inputDataSize; + int8OutDataSize = headPerLoop * colNum; + headTailDataSize = headTail * colNum * sizeof(InDtype); + int8TailOutDataSize = headTail * colNum; + } + + __aicore__ inline void Process() + { + if (batchNum == 0) { + return; + } + // init local tensor + inputTensor_ = buf.GetBuffer(0); + scaleTensor_ = buf.GetBuffer(inputDataSize); + scaleBrcbFp16_ = buf.GetBuffer(inputDataSize + scaleDataSize); + tempQuantFp16_ = + buf.GetBuffer(inputDataSize + scaleDataSize + scaleBrcbFp16DataSize); + int8OutTensor_ = buf.GetBuffer(inputDataSize + scaleDataSize + + scaleBrcbFp16DataSize + tempQuantFp16DataSize); + + uint64_t inputLoopOffset = 0; + uint32_t scaleLoopOffset = 0; + uint64_t batchOffset = 0; + uint64_t calcStartOffset = 0; + uint64_t colOffset = 0; + uint8_t calcRepeatStride = static_cast(colNum / ELE_NUM_FP16); + + SET_FLAG(MTE3, MTE2, EVENT_ID1); + for (uint32_t ubLoopIdx = 0; ubLoopIdx < ubHeadLoop; ubLoopIdx++) { + // scale CopyIn + scaleLoopOffset = ubLoopIdx * headPerLoop; + 
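+ // Each UB loop quantizes headPerLoop heads: per-head scales are copied in, broadcast from [H', 1] to [H', 16] with Brcb, multiplied against the einsum output 128 columns per repeat, and the fp16 product is cast to int8 before being copied back to GM.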
WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + AscendC::DataCopy(scaleTensor_, scaleGm_[scaleLoopOffset], headPerLoop); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + // scale broadcast [H', 1] --> [H', 16] + AscendC::Brcb(scaleBrcbFp16_, scaleTensor_, headPerLoop / 8, {1, 8}); + AscendC::PipeBarrier(); + + inputLoopOffset = ubLoopIdx * headPerLoop * colNum; + SET_FLAG(MTE3, MTE2, EVENT_ID1); + for (uint32_t batchIdx = 0; batchIdx < batchNum; batchIdx++) { + batchOffset = batchIdx * headNum * colNum; + calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset; + // input CopyIn + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + AscendC::DataCopy(inputTensor_, einSumOutGm_[calcStartOffset], + {1, static_cast(inputDataSize / BLOCK_SIZE_32), 0, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + + // quant calc + for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) { + colOffset = colIdx * CONST_128; + AscendC::Mul(tempQuantFp16_[colOffset], inputTensor_[colOffset], scaleBrcbFp16_, CONST_128, + headPerLoop, {1, 1, 0, calcRepeatStride, calcRepeatStride, 1}); + } + AscendC::PipeBarrier(); + + // quant fp16 --> int8 + CastFromF16ToI8(int8OutTensor_, tempQuantFp16_, quantMin_, headPerLoop * colNum); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + + // int8 CopyOut + AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_, + {1, static_cast(int8OutDataSize / BLOCK_SIZE_32), 0, 0}); + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + + // deal with headTail + padLen = (headTail + ELE_NUM_FP16 - 1) / ELE_NUM_FP16 * ELE_NUM_FP16; + SET_FLAG(MTE3, MTE2, EVENT_ID1); + if (headTail > 0) { + // scale CopyIn + scaleLoopOffset = ubHeadLoop * headPerLoop; + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + if (headTail == padLen) { + AscendC::DataCopy(scaleTensor_, scaleGm_[scaleLoopOffset], headTail); + } else { + AscendC::DataCopyExtParams copyParams{1, static_cast(headTail * sizeof(half)), 0, 0, 0}; + AscendC::DataCopyPadExtParams padParams{true, 0, static_cast(padLen - headTail), 0}; + AscendC::DataCopyPad(scaleTensor_, scaleGm_[scaleLoopOffset], copyParams, padParams); + } + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + // scale broadcast [H', 1] --> [H', 16] + AscendC::Brcb(scaleBrcbFp16_, scaleTensor_, padLen / 8, {1, 8}); + AscendC::PipeBarrier(); + + inputLoopOffset = ubHeadLoop * headPerLoop * colNum; + SET_FLAG(MTE3, MTE2, EVENT_ID1); + for (uint32_t batchIdx = 0; batchIdx < batchNum; batchIdx++) { + batchOffset = batchIdx * headNum * colNum; + calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset; + // input CopyIn + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + AscendC::DataCopy(inputTensor_, einSumOutGm_[calcStartOffset], + {1, static_cast(headTailDataSize / BLOCK_SIZE_32), 0, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + + // quant calc + for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) { + colOffset = colIdx * CONST_128; + AscendC::Mul(tempQuantFp16_[colOffset], inputTensor_[colOffset], scaleBrcbFp16_, CONST_128, + headTail, {1, 1, 0, calcRepeatStride, calcRepeatStride, 1}); + } + AscendC::PipeBarrier(); + + // quant fp16 --> int8 + CastFromF16ToI8(int8OutTensor_, tempQuantFp16_, quantMin_, headTail * colNum); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + + // int8 CopyOut + 
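+ // Tail pass: only headTail heads remain, so int8TailOutDataSize (= headTail * colNum) bytes are written back per batch.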
AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_, + {1, static_cast(int8TailOutDataSize / BLOCK_SIZE_32), 0, 0}); + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + +private: + AsdopsBuffer buf; + + AscendC::GlobalTensor einSumOutGm_; + AscendC::GlobalTensor scaleGm_; + AscendC::GlobalTensor quantOutGm_; + + AscendC::LocalTensor inputTensor_; + AscendC::LocalTensor scaleTensor_; + AscendC::LocalTensor scaleBrcbFp16_; + AscendC::LocalTensor tempQuantFp16_; + AscendC::LocalTensor int8OutTensor_; + + // [batchNum, headNum, colNum] + uint32_t batchNum; + uint32_t headNum; + uint32_t colNum; + // ub loop + uint32_t ubHeadLoop; + uint32_t headPerLoop; + uint32_t headTail; + // col loop + uint32_t colLoop; + uint32_t colTail; + + uint32_t currentIdx; + uint64_t currentCoreStartOffset; + uint32_t inputDataSize; // bytes + uint32_t scaleDataSize; + uint32_t scaleBrcbFp16DataSize; + uint32_t tempQuantFp16DataSize; + uint32_t int8OutDataSize; + uint32_t headTailDataSize; + uint32_t int8TailOutDataSize; + + half quantMin_{-128}; + uint32_t padLen; +}; + +#ifdef __DAV_C220_CUBE__ + +struct MatCoord { + uint64_t m{0}; + uint64_t k{0}; + uint64_t n{0}; +}; + +template +class PpMatmulEinSum +{ + using InDtype = half; + using OutDtype = half; + using AccumDtype = float; + + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mad = mmad; + using CopyCcToGm = l0c_to_gm; + + static constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384; + static constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072; + static constexpr uint32_t CONST_16 = 16; + static constexpr uint32_t CONST_256 = 256; + +public: + __aicore__ explicit PpMatmulEinSum(){}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const MlaTilingData &mlaParams); + + __aicore__ __force_inline__ void Process(); + __aicore__ __force_inline__ void PreloadB(); + +private: + __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, MatCoord &tidx); + __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t bIdx, const uint64_t kIdx, const uint64_t nIdx); + __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round); + __aicore__ __force_inline__ void CopyTileB(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round); + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint32_t num_core{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + MatCoord tdim{0}; + MatCoord fdim{0}; + uint32_t core_loop{0}; + uint32_t swizzle_cnt{1}; + uint32_t core_idx{0}; + uint32_t en_shuffle_k = 0; + uint32_t ping_flag{0}; +}; + +template +__aicore__ __force_inline__ void PpMatmulEinSum::Init( + GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const MlaTilingData &mlaParams) +{ +#ifdef __DAV_C220_CUBE__ + batch_size = mlaParams.mm3.numBatch; + m = mlaParams.mm3.m; + k = mlaParams.mm3.k; + n = 
mlaParams.mm3.n; + m0 = mlaParams.mm3.m0; + k0 = mlaParams.mm3.k0; + n0 = mlaParams.mm3.n0; + tdim.m = mlaParams.mm3.mLoop; + tdim.k = mlaParams.mm3.kLoop; + tdim.n = mlaParams.mm3.nLoop; + core_loop = mlaParams.mm3.coreLoop; + swizzle_cnt = mlaParams.mm3.swizzleCount; + num_core = mlaParams.mm3.blockDim; + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC)); + + AsdopsBuffer buf; + l1_base_a = buf.template GetBuffer(0); + l1_base_b = buf.template GetBuffer(RoundUp(m0 * k0 * sizeof(InDtype))); + l0a_base = buf.template GetBuffer(0); + l0b_base = buf.template GetBuffer(0); +#endif + return; +} + +template +__aicore__ __force_inline__ void +PpMatmulEinSum::GetBaseBlockIdx(uint64_t index, MatCoord &tidx) +{ + uint64_t in_batch_idx = index % (tdim.m * tdim.n); + if constexpr (swizzleDirect == 0) { // Zn + uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = tdim.m - swizzle_cnt * tile_block_idx; + } + tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + tidx.n = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + tidx.n = tdim.n - tidx.n - 1; + } + } else if constexpr (swizzleDirect == 1) { // Nz + uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = tdim.n - swizzle_cnt * tile_block_idx; + } + tidx.m = in_tile_block_idx / n_col; + tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + tidx.m = tdim.m - tidx.m - 1; + } + } + return; +} + +template +__aicore__ __force_inline__ void PpMatmulEinSum::PreloadB() +{ +#ifdef __DAV_C220_CUBE__ + uint64_t batch_idx = core_idx / tdim.n / tdim.m; + uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; + MatCoord tidx{0}; + GetBaseBlockIdx(core_idx, tidx); + uint64_t offset_b = GetOffsetB(batch_idx, shuffle_k, tidx.n); + uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t n_round = RoundUp(n_actual); + uint64_t k_actual = (shuffle_k == tdim.k - 1) ? 
k - shuffle_k * k0 : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + SET_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + CopyTileB(l1_base_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); +#endif +} + +template +__aicore__ __force_inline__ uint64_t PpMatmulEinSum::GetOffsetB( + const uint64_t batchIdx, const uint64_t kIdx, const uint64_t nIdx) +{ + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + return batchIdx * k * n + nIdx * n0 * k + kIdx * k0; + } else { + return batchIdx * k * n + kIdx * k0 * n + nIdx * n0; + } + } else { + if constexpr (transB) { + return batchIdx * RoundUp(n) * RoundUp(k) + kIdx * k0 * RoundUp(n) + + nIdx * n0 * CONST_16; + } else { + return batchIdx * RoundUp(k) * RoundUp(n) + nIdx * n0 * RoundUp(k) + + kIdx * k0 * CONST_16; + } + } +} + +template +__aicore__ __force_inline__ void PpMatmulEinSum::CopyTileA( + AscendC::LocalTensor &dstTensor, const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round) +{ + if ((m == 1) || (m_actual == 1)) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + 1, // nTileActual + CONST_16, // nTileCeil + 1, // nVal + k_actual, // kTileActual + k_round, // kTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + (k + splitGapA) * batch_size); // dVal + } +} + +template +__aicore__ __force_inline__ void PpMatmulEinSum::CopyTileB( + AscendC::LocalTensor &dstTensor, const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round) +{ + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } +} + +template +__aicore__ __force_inline__ void PpMatmulEinSum::Process() +{ +#ifdef __DAV_C220_CUBE__ + if (block_idx >= num_core) { + WaitFlagDev(MM2OUT); + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(BMM3SPLIT); + return; + } + using LocalTensor = AscendC::LocalTensor; + + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) { + uint64_t batch_idx = loop_idx / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx, tidx); + uint64_t offset_a = 0, offset_b = 0, offset_a_next = 0, offset_b_next = 0; + uint64_t offset_c = tidx.m * m0 * batch_size * (n + splitGapC) + batch_idx * 
(n + splitGapC) + tidx.n * n0; + uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round = RoundUp(m_actual); + uint64_t n_round = RoundUp(n_actual); + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16; + uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; + offset_a = tidx.m * m0 * batch_size * (k + splitGapA) + batch_idx * (k + splitGapA) + shuffle_k * k0; + uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (loop_idx == core_idx) { + WaitFlagDev(MM2OUT); + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(BMM3SPLIT); + + // Copy A from gm to l1 buffer + WAIT_FLAG(MTE1, MTE2, event_id); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); + SET_FLAG(MTE2, MTE1, event_id); + + WAIT_FLAG(MTE1, MTE2, event_id + 2); + SET_FLAG(MTE2, MTE1, event_id + 2); + } + + for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) { + shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k; + uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + fdim.k = (k_actual + k_part_len - 1) / k_part_len; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (tidx.k < tdim.k - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1); + offset_a_next = + tidx.m * m0 * batch_size * (k + splitGapA) + batch_idx * (k + splitGapA) + shuffle_k_next * k0; + offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, tidx.n); + + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) { + uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx + num_core, tidx); + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0; + uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? 
(n - tidx.n * n0) : n0; + uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + offset_a_next = + tidx.m * m0 * batch_size * (k + splitGapA) + b_idx_next * (k + splitGapA) + shuffle_k_next * k0; + offset_b_next = GetOffsetB(b_idx_next, shuffle_k_next, tidx.n); + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual_next, m_round_next, k_actual_next, k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual_next, n_round_next); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + MatCoord fidx{0}; + for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) { + uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len; + uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len; + + auto mte1_mad_ping_flag = 1 - fidx.k % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1; + LocalTensor l0a_buf = l0a_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + LocalTensor l0b_buf = l0b_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1)) { + l1_to_l0_a( + l0a_buf, // dst + l1_buf_a[fidx.k * k_part_len], // src + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / CONST_16, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + 2); + } + if constexpr (transB) { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / CONST_16, // kSrcStride + 1, // nDstStride + k0_round / CONST_16); // kDstStride + } else { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // nSrcStride + 1, // kSrcStride + 1, // nDstStride + n_round / CONST_16); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id + 2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (tidx.k == 0 && fidx.k == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + m_actual, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC 
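+ // init_c is true only for the first k sub-tile of the first k tile, so the mmad accumulates into L0C across the whole K loop and the result is copied to GM once per base block below.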
+ + AscendC::PipeBarrier(); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // mTileActual + n_actual, // nTileActual + m_round, // mTileCeil + (n + splitGapC) * batch_size); // nActual + SET_FLAG(FIX, M, EVENT_ID0); + } + + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); +#endif +} + +template +class PpMatmulW8a8 +{ + using InDtype = int8_t; + using OutDtype = half; + using AccumDtype = int32_t; + using BiasDtype = int32_t; + using ScaleDtype = uint64_t; + + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mmad = mmad; + using CopyCcToGm = l0c_to_gm; + + static constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; + static constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; + static constexpr uint64_t BLOCK_SIZE_16 = 16; + static constexpr uint64_t BLOCK_SIZE_32 = 32; + static constexpr uint64_t CUBE_MATRIX_SIZE_512 = 512; + static constexpr uint64_t FB_BUFF_SIZE = 1024 * 7; + static constexpr uint64_t SCALE_L1_LEN = 4096; + static constexpr uint64_t BIAS_L1_LEN = 2048; + static constexpr uint64_t CONST_4 = 4; + static constexpr uint64_t CONST_32 = 32; + static constexpr uint64_t CONST_64 = 64; + static constexpr uint64_t CONST_128 = 128; + +public: + __aicore__ PpMatmulW8a8() {}; + + __aicore__ __force_inline__ void Init(AscendC::GlobalTensor &gm_a, AscendC::GlobalTensor &gm_b, + AscendC::GlobalTensor &gm_bias, + AscendC::GlobalTensor &gm_descale, + AscendC::GlobalTensor &gm_c, MlaTilingData &mlaParams, + uint32_t mode); + __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t batchIdx, const uint64_t mIdx, uint64_t kIdx); + __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t batchIdx, const uint64_t kIdx, uint64_t nIdx); + __aicore__ __force_inline__ void CopyTileA(const AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round); + __aicore__ __force_inline__ void CopyTileB(const AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round); + __aicore__ __force_inline__ void Process(); + __aicore__ __force_inline__ void PreloadDoubleWeight(); + +private: + __aicore__ __force_inline__ void InitBuffer(); + __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, uint64_t &m_idx, uint64_t &n_idx); + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_bias; + AscendC::GlobalTensor gm_descale; + AscendC::GlobalTensor gm_c; + + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + AscendC::LocalTensor bias_l1; + AscendC::LocalTensor scale_l1; + AscendC::LocalTensor scale_fb; + + uint64_t bias_bt{0}; + uint32_t core_num{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + uint32_t m_loop{0}; + uint32_t n_loop{0}; + uint32_t k_loop{0}; + uint32_t core_loop{0}; + 
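+ // core_loop is the total number of (batch, mTile, nTile) base blocks from tiling; each cube core starts at core_idx and strides through them by core_num in Process().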
uint32_t core_idx{0}; + uint32_t ping_flag{0}; + uint32_t swizzle_cnt{1}; + uint32_t en_shuffle_k{0}; + uint64_t b0mat_pingpong_buffer_len{0}; + bool load_all_Amat_flag{false}; + uint32_t MM1_MM2_mode{0}; +}; + +template +__aicore__ __force_inline__ void PpMatmulW8a8::Init( + AscendC::GlobalTensor &gm_a, AscendC::GlobalTensor &gm_b, + AscendC::GlobalTensor &gm_bias, AscendC::GlobalTensor &gm_descale, + AscendC::GlobalTensor &gm_c, MlaTilingData &mlaParams, uint32_t mode) +{ + this->gm_a = gm_a; + this->gm_b = gm_b; + this->gm_bias = gm_bias; + this->gm_descale = gm_descale; + this->gm_c = gm_c; + MM1_MM2_mode = mode; + if (mode == 0) { + batch_size = mlaParams.mm1.numBatch; + m = mlaParams.mm1.m; + k = mlaParams.mm1.k; + n = mlaParams.mm1.n; + m0 = mlaParams.mm1.m0; + k0 = mlaParams.mm1.k0; + n0 = mlaParams.mm1.n0; + m_loop = mlaParams.mm1.mLoop; + k_loop = mlaParams.mm1.kLoop; + n_loop = mlaParams.mm1.nLoop; + core_loop = mlaParams.mm1.coreLoop; + swizzle_cnt = mlaParams.mm1.swizzleCount; + en_shuffle_k = mlaParams.mm1.enShuffleK; + core_num = mlaParams.mm1.blockDim; + load_all_Amat_flag = mlaParams.mm1.enLoadAllAmat; + b0mat_pingpong_buffer_len = mlaParams.mm1.b0matPingPongBufferLen; + } else { + batch_size = mlaParams.mm2.numBatch; + m = mlaParams.mm2.m; + k = mlaParams.mm2.k; + n = mlaParams.mm2.n; + m0 = mlaParams.mm2.m0; + k0 = mlaParams.mm2.k0; + n0 = mlaParams.mm2.n0; + m_loop = mlaParams.mm2.mLoop; + k_loop = mlaParams.mm2.kLoop; + n_loop = mlaParams.mm2.nLoop; + core_loop = mlaParams.mm2.coreLoop; + swizzle_cnt = mlaParams.mm2.swizzleCount; + en_shuffle_k = mlaParams.mm2.enShuffleK; + core_num = mlaParams.mm2.blockDim; + load_all_Amat_flag = mlaParams.mm2.enLoadAllAmat; + b0mat_pingpong_buffer_len = mlaParams.mm2.b0matPingPongBufferLen; + } + + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + + InitBuffer(); + return; +} + +template +__aicore__ __force_inline__ uint64_t PpMatmulW8a8::GetOffsetA( + const uint64_t batch_idx, const uint64_t m_idx, uint64_t k_idx) +{ + if constexpr (transA) { + return batch_idx * m * k + k_idx * k0 * m + m_idx * m0; + } else { + return batch_idx * m * k + m_idx * m0 * k + k_idx * k0; + } +} + +template +__aicore__ __force_inline__ uint64_t PpMatmulW8a8::GetOffsetB( + const uint64_t batch_idx, const uint64_t k_idx, uint64_t n_idx) +{ + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + return batch_idx * k * n + n_idx * n0 * k + k_idx * k0; + } else { + return batch_idx * k * n + k_idx * k0 * n + n_idx * n0; + } + } else { + if constexpr (transB) { + return batch_idx * RoundUp<16>(n) * RoundUp<32>(k) + k_idx * k0 * RoundUp<16>(n) + n_idx * n0 * CONST_32; + } else { + return batch_idx * RoundUp<16>(k) * RoundUp<32>(n) + n_idx * n0 * RoundUp<16>(k) + k_idx * k0 * CONST_32; + } + } +} + +template +__aicore__ __force_inline__ void PpMatmulW8a8::CopyTileA( + const AscendC::LocalTensor &dstTensor, const AscendC::GlobalTensor &srcTensor, + const uint64_t m_actual, const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round) +{ + if ((m == 1) || (m_actual == 1 && !transA)) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + 1, BLOCK_SIZE_16, 1, k_actual, k_round, k); + } else { + if constexpr (transA) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + m_actual, // dTileActual + m_round, // dTileCeil + m); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + n, // nVal + 
k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } + } +} + +template +__aicore__ __force_inline__ void PpMatmulW8a8::CopyTileB( + const AscendC::LocalTensor &dstTensor, const AscendC::GlobalTensor &srcTensor, + const uint64_t k_actual, const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round) +{ + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp<16>(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp<32>(k)); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + RoundUp<16>(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp<32>(n)); // dVal + } + } +} + +template +__aicore__ __force_inline__ void PpMatmulW8a8::InitBuffer() +{ + AsdopsBuffer buf; + l1_base_a = buf.template GetBuffer(SCALE_L1_LEN + BIAS_L1_LEN); + + // try load all A matrix + uint32_t a_l1_size = RoundUp(m) * RoundUp(k); + if (!load_all_Amat_flag) { + a_l1_size = RoundUp(m0 * k0); + if constexpr (transA || !transB) { + a_l1_size = RoundUp(RoundUp(m0) * k0); + } + } + + l1_base_b = l1_base_a[a_l1_size]; + bias_l1 = buf.template GetBuffer(0); + scale_l1 = buf.template GetBuffer(BIAS_L1_LEN); + scale_fb.InitBuffer(0, FB_BUFF_SIZE); + + l0a_base = buf.template GetBuffer(0); + l0b_base = buf.template GetBuffer(0); + l0c_buf = buf.template GetBuffer(0); + return; +} + +template +__aicore__ __force_inline__ void PpMatmulW8a8::GetBaseBlockIdx( + uint64_t index, uint64_t &m_idx, uint64_t &n_idx) +{ + uint64_t in_batch_idx = index % (m_loop * n_loop); + if constexpr (swizzleDir == 0) { // Zn + uint64_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = m_loop - swizzle_cnt * tile_block_idx; + } + m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + n_idx = in_tile_block_idx / n_row; + if ((tile_block_idx & 0b1) != 0) { + n_idx = n_loop - n_idx - 1; + } + } else { // Nz + uint64_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = n_loop - swizzle_cnt * tile_block_idx; + } + m_idx = in_tile_block_idx / n_col; + n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if ((tile_block_idx & 0b1) != 0) { + m_idx = m_loop - m_idx - 1; + } + } + return; +} + +template +__aicore__ __force_inline__ void +PpMatmulW8a8::PreloadDoubleWeight() +{ +#ifdef __DAV_C220_CUBE__ + if (core_idx < core_num) { + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(core_idx, m_idx, n_idx); + uint64_t shuffle_k = en_shuffle_k ? 
core_idx % k_loop : 0; + uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx); + uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint64_t n_round = RoundUp(n_actual); + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + CopyTileB(l1_base_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + if (k_loop > 1) { + uint64_t shuffle_k = en_shuffle_k ? (core_idx + 1) % k_loop : 1; + uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx); + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + CopyTileB(l1_base_b[b0mat_pingpong_buffer_len], gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + } +#endif +} + +template +__aicore__ __force_inline__ void PpMatmulW8a8::Process() +{ + using LocalTensor = AscendC::LocalTensor; + if (core_idx >= core_num) { + if (MM1_MM2_mode == 0) { + WaitFlagDev(MM1); + } else if (MM1_MM2_mode == 1) { + WaitFlagDev(MM2QUANT); + } + return; + } + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID7); + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) { + uint64_t batch_idx = loop_idx / n_loop / m_loop; + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(loop_idx, m_idx, n_idx); + uint64_t offset_a; + uint64_t offset_b; + uint64_t offset_bias; + uint64_t offset_scalar; + uint64_t offset_a_next; + uint64_t offset_b_next; + uint64_t offset_c = batch_idx * m * n + m_idx * m0 * n + n_idx * n0; + uint64_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint64_t m_round = 0; + uint64_t n_round = 0; + uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0; + uint64_t m_round_16 = RoundUp(m_actual); + uint64_t m_round_32 = RoundUp(m_actual); + if constexpr (transA) { + m_round = m_round_32; + } else { + m_round = m_round_16; + } + if constexpr (transB) { + n_round = RoundUp(n_actual); + } else { + n_round = RoundUp(n_actual); + } + + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = 0; + k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / BLOCK_SIZE_32 * BLOCK_SIZE_32; + + offset_b = GetOffsetB(batch_idx, shuffle_k, n_idx); + offset_bias = batch_idx * n + n_idx * n0; + offset_scalar = batch_idx * n + n_idx * n0; + + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + if constexpr (withBias) { + WAIT_FLAG(MTE1, MTE2, EVENT_ID7); + gm_to_l1(bias_l1, // dst + gm_bias[offset_bias], // src + 1, BLOCK_SIZE_16, 1, n_actual, + n_round, n); + SET_FLAG(MTE2, MTE1, EVENT_ID6); + } + + // 3.13 Wait after Scalar + if (loop_idx == core_idx) { + if (MM1_MM2_mode == 0) { + WaitFlagDev(MM1); + } else if (MM1_MM2_mode == 1) { + WaitFlagDev(MM2QUANT); + } + } + + WAIT_FLAG(MTE1, MTE2, event_id); + LocalTensor l1_buf_a = + load_all_Amat_flag ? l1_base_a : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b = ping_flag ? 
l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + if (load_all_Amat_flag) { + if (loop_idx == core_idx) { + offset_a = GetOffsetA(batch_idx, m_idx, 0); + uint64_t k_actual_first = k; + uint64_t k_round_first = RoundUp(k_actual_first); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual_first, k_round_first); + } + } else { + offset_a = GetOffsetA(batch_idx, m_idx, shuffle_k); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); + } + SET_FLAG(MTE2, MTE1, event_id); + + WAIT_FLAG(MTE1, MTE2, event_id + CONST_2); + // The first weight matrix block is loaded in advance. + if (loop_idx != core_idx) { + CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + SET_FLAG(MTE2, MTE1, event_id + CONST_2); + + WAIT_FLAG(FIX, MTE2, EVENT_ID0); + gm_to_l1(scale_l1, // dst + gm_descale[offset_scalar], // src + 1, BLOCK_SIZE_16, 1, n_actual, + n_round, n); + SET_FLAG(MTE2, FIX, EVENT_ID0); + WAIT_FLAG(MTE2, FIX, EVENT_ID0); + l1_to_fb(scale_fb, // dst + scale_l1, // src + 1, // nBurst + CeilDiv(n_actual * sizeof(ScaleDtype)), // lenBurst + 0, // srcGap + 0); // dstGap + // when move scalar form L1 to fifpipe end, can move A/B from gm to L1 + SET_FLAG(FIX, MTE2, EVENT_ID0); + + for (uint64_t k_idx = 0; k_idx < k_loop; k_idx++) { + shuffle_k = en_shuffle_k ? (k_idx + core_idx) % k_loop : k_idx; + uint32_t k_actual = (shuffle_k == (k_loop - 1)) ? (k - shuffle_k * k0) : k0; + uint32_t k_round = RoundUp(k_actual); + uint32_t k_part_loop = (k_actual + k_part_len - 1) / k_part_len; + + // --------- load whole A in l1a addr change ------------- + LocalTensor l1_buf_a = load_all_Amat_flag ? (l1_base_a[k_idx * m0 * k0 * sizeof(int8_t)]) + : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (k_idx < k_loop - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + k_idx + 1) % k_loop : k_idx + 1; + + offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, n_idx); + uint32_t k_actual_next = (shuffle_k_next == (k_loop - 1)) ? (k - shuffle_k_next * k0) : k0; + uint32_t k_round_next = RoundUp(k_actual_next); + + LocalTensor l1_buf_a_next = + load_all_Amat_flag ? l1_base_a : ((1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + auto event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + WAIT_FLAG(MTE1, MTE2, event_id_next); + if (!load_all_Amat_flag) { + offset_a_next = GetOffsetA(batch_idx, m_idx, shuffle_k_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next); + } + SET_FLAG(MTE2, MTE1, event_id_next); + + WAIT_FLAG(MTE1, MTE2, event_id_next + CONST_2); + if (loop_idx != core_idx || k_idx != 0) { // The second weight matrix is preloaded. + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round); + } + SET_FLAG(MTE2, MTE1, event_id_next + CONST_2); + } + + for (int k_part_idx = 0; k_part_idx < k_part_loop; k_part_idx++) { + uint32_t k0_round = (k_part_idx < k_part_loop - 1) ? k_part_len : k_round - k_part_idx * k_part_len; + uint32_t k0_actual = (k_part_idx < k_part_loop - 1) ? k_part_len : k_actual - k_part_idx * k_part_len; + + auto mte1_mad_ping_flag = 1 - k_part_idx % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? 
EVENT_ID0 : EVENT_ID1; + AscendC::LocalTensor l0a_buf = l0a_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN]; + AscendC::LocalTensor l0b_buf = l0b_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (k_part_idx == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1 && !transA)) { + l1_to_l0_a( + l0a_buf, l1_buf_a[k_part_idx * k_part_len], + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + if constexpr (transA) { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[k_part_idx * k_part_len * BLOCK_SIZE_32], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + k_round / BLOCK_SIZE_16, // mSrcStride + 1, // kSrcStride + k0_round / BLOCK_SIZE_32, // mDstStride + 1); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[k_part_idx * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / BLOCK_SIZE_16, // kSrcStride + k0_round / BLOCK_SIZE_32, // mDstStride + 1); // kDstStride + } + } + if (k_part_idx == k_part_loop - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (k_part_idx == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + CONST_2); + } + if constexpr (transB) { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[k_part_idx * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / BLOCK_SIZE_16, // kSrcStride + 1, // nDstStride + k0_round / BLOCK_SIZE_32); // kDstStride + } else { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[k_part_idx * k_part_len * BLOCK_SIZE_32], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + k_round / BLOCK_SIZE_16, // nSrcStride + 1, // kSrcStride + 1, // nDstStride + n_round / BLOCK_SIZE_16); // kDstStride + } + if (k_part_idx == k_part_loop - 1) { + SET_FLAG(MTE1, MTE2, event_id + CONST_2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (k_idx == 0 && k_part_idx == 0); + bool sp_flag = (m != 1 && m_actual == 1 && transA); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + if (init_c) { + if constexpr (withBias) { + WAIT_FLAG(MTE2, MTE1, EVENT_ID6); + l1_to_bt( + bias_bt, // dst + bias_l1, // src + 0, // convControl + 1, // nBurst + CeilDiv(n_actual * sizeof(BiasDtype)), // lenBurst + 0, // srcGap + 0); // dstGap + SET_FLAG(MTE1, MTE2, EVENT_ID7); // bias ready, mte2 can begin move A/B or scale + SET_FLAG(MTE1, M, EVENT_ID7); // bias ready, mmad can begin + WAIT_FLAG(MTE1, M, EVENT_ID7); // wait move bias from L1 to BT + Mmad(l0c_buf, l0a_buf, l0b_buf, ((uint64_t)bias_bt), + sp_flag ? m_round_16 : m_actual, // m + n_actual, // n + k0_actual, // k + 0); // cmatrixInitVal + } else { + Mmad(l0c_buf, l0a_buf, l0b_buf, + sp_flag ? m_round_16 : m_actual, // m + n_actual, // n + k0_actual, // k + 1); // cmatrixInitVal + } + } else { + Mmad(l0c_buf, l0a_buf, l0b_buf, + sp_flag ? 
m_round_16 : m_actual, // m + n_actual, // n + k0_actual, // k + 0); // cmatrixInitVal + } + AscendC::PipeBarrier(); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + AscendC::PipeBarrier(); + SetFpc(scale_fb, false); + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // MSize + n_actual, // NSize + m_round_16, // srcStride + n); // dstStride_dst_D + SET_FLAG(FIX, M, EVENT_ID0); + } + + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID7); +} +#endif + +template +class MLAOperation +{ + using qOutDtype = typename std::conditional_t; + using kNopeDtype = typename std::conditional_t; + +public: + __aicore__ inline MLAOperation(const MlaTilingData &mlaParams_, GM_ADDR tilingGm) + { + blockIdx = AscendC::GetBlockIdx(); +#ifdef __DAV_C220_VEC__ + sub_block_idx = static_cast(GetSubBlockidx()); +#endif + vectorBlockIdx = (blockIdx / 2) * 2 + sub_block_idx; + this->n = mlaParams_.n; + this->num_core_ = mlaParams_.rmsNumCore1; + this->num_col_1 = mlaParams_.rmsNumCol1; + this->num_col_2 = mlaParams_.rmsNumCol2; + this->num_row = mlaParams_.n; + this->epsilon_ = 1e-6; + this->mlaParams = mlaParams_; + } + + __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, + GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, + GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, + GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, + GM_ADDR slotMappingGm, GM_ADDR wuqGm, GM_ADDR bias2Gm, GM_ADDR wukGm, + GM_ADDR descale1Gm, GM_ADDR descale2Gm, GM_ADDR gmCtkvScale, GM_ADDR gmQnopeScale, + GM_ADDR qGm, GM_ADDR keycacheOutGm, GM_ADDR qGm2, GM_ADDR keycacheOutGm2, GM_ADDR s1Gm, + GM_ADDR s2Gm, GM_ADDR s3Gm) + { + s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(s1Gm)); + wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm)); + bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm)); + descale1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t *>(descale1Gm)); + s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(s3Gm)); + +#ifdef __DAV_C220_CUBE__ + mm_w8a8_1.Init(s1GmTensor, wdqkvGmTensor, bias1gmTensor, descale1gmTensor, s3GmTensor, mlaParams, 0); + mm_w8a8_1.PreloadDoubleWeight(); +#endif + hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm)); + gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma1Gm)); + quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm)); + quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); + + gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma2Gm)); + quantScale2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale2Gm)); + quantScale3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gmCtkvScale)); + quantOffset2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset2Gm)); + gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma3Gm)); + sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(sin1Gm)); + 
cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(cos1Gm)); + sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(sin2Gm)); + cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(cos2Gm)); + keycacheGmTensor1.SetGlobalBuffer(reinterpret_cast<__gm__ kNopeDtype *>(keycacheOutGm)); + keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(keycacheOutGm2)); + slotMappingGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slotMappingGm)); + wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wuqGm)); + wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(wukGm)); + descale2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t *>(descale2Gm)); + s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(s2Gm)); + qGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ qOutDtype *>(qGm)); + qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2)); + bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); + + beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta1Gm)); + beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm)); +#ifdef __DAV_C220_CUBE__ + mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1); + if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { + mm_ein_sum.Init(s2Gm, wukGm, s1Gm, mlaParams); + } else { + mm_ein_sum.Init(s2Gm, wukGm, qGm, mlaParams); + } +#endif + +#ifdef __DAV_C220_VEC__ + // rmsnormQuant + row_work = (num_row + num_core_ - 1) / num_core_; + row_work_ = 0; + uint32_t need_core = (num_row + row_work - 1) / row_work; + if (vectorBlockIdx < need_core - 1) { + row_work_ = row_work; + } else if (vectorBlockIdx == need_core - 1) { + row_work_ = num_row - (need_core - 1) * row_work; + } else { + row_work_ = 0; + } + this->splitN = mlaParams.perTaskNum; + Quant1.Init(gamma1GmTensor, beta1GmTensor, quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor, + s1GmTensor, 0, num_col_1, 0.0001395089285, + vectorBlockIdx * static_cast(row_work) * num_col_1, + vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); + + rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, s3GmTensor, + s1GmTensor, SPLIT_SIZE_ONE, num_col_2, 0.000651041666, + vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + ropeFp16.RopeInit(s2GmTensor, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); + einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); + ubTensor = buf.GetBuffer(0); + ub8Tensor = buf.GetBuffer(0); + ub32Tensor = buf.GetBuffer(0); +#endif + } + + __aicore__ inline void ProcessCube(); + + __aicore__ inline void ProcessVector(); + +private: + constexpr static uint32_t C0_SIZE = 16; + constexpr static uint32_t I8_C0_SIZE = 32; + + template + __aicore__ inline void RmsNormAndRopeConvergence1( + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &sinTensor, const AscendC::LocalTensor &cosTensor, + const AscendC::LocalTensor &slotMappingTensor, const uint32_t sN, + const AscendC::LocalTensor &rmsNormTensor, const AscendC::LocalTensor &gammaFp32, + const AscendC::LocalTensor &ropeKTensor, const AscendC::LocalTensor &ropeKRevertTensor, + const AscendC::LocalTensor &calTensor, const AscendC::LocalTensor &outTmpTensor, + AscendC::LocalTensor &tmpfp16, AscendC::LocalTensor &int8OutTensor, float quantScale3) + { + int64_t 
slotMapGmOffset = vectorBlockIdx * row_work; + AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], + AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); + SET_FLAG(MTE2, V, EVENT_ID2); + WAIT_FLAG(MTE2, V, EVENT_ID2); + SET_FLAG(MTE2, S, EVENT_ID2); + WAIT_FLAG(MTE2, S, EVENT_ID2); + for (uint64_t loop = 0; loop < sN; ++loop) { + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); + if (slotValue == -1) { + continue; + } + AscendC::DataCopy(srcTensor, s3GmTensor[offset], SPLIT_SIZE_ONE); + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + SET_FLAG(MTE2, V, EVENT_ID0); + // ND + uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + // NZ + uint32_t outer_idx = slotValue / 128; + uint32_t inner_idx = slotValue % 128; + SET_FLAG(S, MTE3, EVENT_ID0); + /* RmsNorm start */ + WAIT_FLAG(MTE2, V, EVENT_ID0); + Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], + SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(V, S, EVENT_ID1); + WAIT_FLAG(V, S, EVENT_ID1); + float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + SET_FLAG(S, V, EVENT_ID1); + WAIT_FLAG(S, V, EVENT_ID1); + AscendC::PipeBarrier(); + Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { + // quant + Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + } else { + AscendC::PipeBarrier(); + if (std::is_same::value) { + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + } else { + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + } + } + /* RmsNorm end */ + // /* Rope K start */ + uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; + Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + Cast(ropeKRevertTensor[revertOffset], 
srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + revertOffset); + Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + revertOffset); + Duplicate(calTensor, static_cast(-1), revertOffset); + Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); + AscendC::PipeBarrier(); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + /* Rope K end */ + // reshapeAndcache + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (cacheMode == CACHE_MODE_KVCACHE) { + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + } else if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { + // NZ + int64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE; + uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; + AscendC::DataCopyExtParams outExt; + // nope:int8 nz + outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE; + outExt.blockLen = I8_C0_SIZE * sizeof(int8_t); + outExt.srcStride = 0; + outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t); + DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt); + // rope:T1 nz + outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + } else if constexpr (cacheMode == CACHE_MODE_NZCACHE) { + uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE; + uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; + // nope:T1 nz + AscendC::DataCopyExtParams outExt; + outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt); + // rope:T1 nz + outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + } else { + // keycache1 + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + // keycache2 + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], + SPLIT_RMSNRORM_SIZE_TWO); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + } + +private: + uint32_t n; + uint32_t splitN; + uint32_t rotaryCoeff; + uint32_t blockIdx; + uint32_t sub_block_idx; + uint32_t 
vectorBlockIdx; + uint32_t blockOffset; + uint32_t perTaskNum; + uint32_t resTaskNum; + MlaTilingData mlaParams; + + // rmsnormQuant + uint32_t num_core_; + uint32_t num_col_1; + uint32_t num_col_2; + float epsilon_; + uint32_t num_row; + uint32_t quantMin_; + uint32_t row_work; + uint32_t row_work_; + + AsdopsBuffer buf; + AscendC::LocalTensor ubTensor; + AscendC::LocalTensor ub8Tensor; + AscendC::LocalTensor ub32Tensor; + + AscendC::GlobalTensor hiddenStateGmTensor; + + AscendC::GlobalTensor gamma1GmTensor; + AscendC::GlobalTensor quantScale1GmTensor; + AscendC::GlobalTensor quantOffset1GmTensor; + + AscendC::GlobalTensor wdqkvGmTensor; + AscendC::GlobalTensor gamma2GmTensor; + AscendC::GlobalTensor quantScale2GmTensor; + AscendC::GlobalTensor quantScale3GmTensor; + AscendC::GlobalTensor quantOffset2GmTensor; + AscendC::GlobalTensor gamma3GmTensor; + AscendC::GlobalTensor sin1GmTensor; + AscendC::GlobalTensor cos1GmTensor; + AscendC::GlobalTensor sin2GmTensor; + AscendC::GlobalTensor cos2GmTensor; + AscendC::GlobalTensor keycacheGmTensor1; + AscendC::GlobalTensor keycacheGmTensor2; + AscendC::GlobalTensor slotMappingGmTensor; + AscendC::GlobalTensor wuqGmTensor; + AscendC::GlobalTensor wukGmTensor; + + AscendC::GlobalTensor qGmTensor; + AscendC::GlobalTensor qGmTensor2; + AscendC::GlobalTensor s1GmTensor; + AscendC::GlobalTensor s2GmTensor; + AscendC::GlobalTensor s3GmTensor; + AscendC::GlobalTensor descale1gmTensor; + AscendC::GlobalTensor descale2gmTensor; + AscendC::GlobalTensor beta1GmTensor; + AscendC::GlobalTensor beta2GmTensor; + + AscendC::GlobalTensor bias1gmTensor; + AscendC::GlobalTensor bias2gmTensor; + +#ifdef __DAV_C220_CUBE__ + PpMatmulW8a8 mm_w8a8_1; + PpMatmulW8a8 mm_w8a8_2; + static constexpr uint64_t splitGapC = cacheMode == CACHE_MODE_KVCACHE ? 
CONST_64 : CONST_0; + PpMatmulEinSum mm_ein_sum; +#endif + +#ifdef __DAV_C220_VEC__ + Quant Quant1; + RmsNormQuant rmsNormQuant2; + RopeFp16 ropeFp16; + EinSumQuant einSumQuant; +#endif +}; + +template +__aicore__ inline void MLAOperation::ProcessCube() +{ +#ifdef __DAV_C220_CUBE__ + mm_w8a8_1.Process(); + FftsCrossCoreSync(RMSNORMQUANT2); + WaitFlagDev(RMSNORMQUANT2); + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(MM1QUANT); + + mm_w8a8_2.PreloadDoubleWeight(); + mm_w8a8_2.Process(); + FftsCrossCoreSync(MM2OUT); + mm_ein_sum.PreloadB(); + mm_ein_sum.Process(); + if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { + FftsCrossCoreSync(EINSUMOUT); + WaitFlagDev(EINSUMOUT); + FftsCrossCoreSync(EINSUMQUANT); + } +#endif +} + +template +__aicore__ inline void MLAOperation::ProcessVector() +{ +#ifdef __DAV_C220_VEC__ + if (row_work_ != 0) { + uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(HIDDTEN_STATE * 2); + AscendC::LocalTensor beta_tensor = + buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); + AscendC::LocalTensor scale_tensor = + buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); + AscendC::LocalTensor offset_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32); + AscendC::LocalTensor res1_tensor = + buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64); + AscendC::LocalTensor res3_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor output_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 32); + Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, + res3_tensor); + } + FftsCrossCoreSync(QUANT1); + WaitFlagDev(QUANT1); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(MM1); + + WaitFlagDev(MM1QUANT); + if (row_work_ != 0) { + uint32_t num_col_align_int8 = (num_col_2 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor beta_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + AscendC::LocalTensor scale_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2); + AscendC::LocalTensor offset_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32); + AscendC::LocalTensor res1_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + AscendC::LocalTensor res3_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor output_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 
32); + rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, + res1_tensor, res3_tensor); + } + FftsCrossCoreSync(MM2); + WaitFlagDev(MM2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(MM2QUANT); + + if (row_work_ != 0) { + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor sin_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + AscendC::LocalTensor cos_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); + AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + int32_t rms3_ub_offset = + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); + + int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + + 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4; + AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); + + AscendC::LocalTensor tmpfp16; + AscendC::LocalTensor int8OutTensor; + float scale3 = 0; + if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { + // quantScale3 + AscendC::LocalTensor quantScaleTensor = buf.GetBuffer(rms3_ub_offset); + AscendC::LocalTensor floatQuantScaleTensor = + buf.GetBuffer(rms3_ub_offset + 32); + // int8out + tmpfp16 = buf.GetBuffer(rms3_ub_offset + + SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); + int8OutTensor = buf.GetBuffer(out_ub_offset); + AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(floatQuantScaleTensor, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + scale3 = 1 / (float)(floatQuantScaleTensor.GetValue(0)); + } + + RmsNormAndRopeConvergence1( + input_tensor, // n * 576 + gamma_tensor, // gamma + sin_tensor, // sin + cos_tensor, // cons + slotMapping_tensor, // slotMapping + row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO + + SPLIT_RMSNRORM_SIZE_TWO], + temp_tensor, tmpfp16, int8OutTensor, scale3); + } + WaitFlagDev(BMM3SPLIT); + ropeFp16.Process(); + + if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { + WaitFlagDev(EINSUMQUANT); + einSumQuant.Process(); + PIPE_BARRIER(ALL); + } +#endif +} + +} // namespace MLAPO_FP16 diff --git a/csrc/ops.h b/csrc/ops.h index 4773992230..6364005096 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -124,4 +124,40 @@ namespace vllm_ascend { uint32_t output_hidden_dim, uint32_t slice_offset, uint32_t output_full_dim); + + extern void mla_preprocess_impl( + void* stream, + void* hidden_state, + void* gamma1, + void* beta1, + void* quant_scale1, + void* quant_offset1, + void* wdqkv, + void* bias1, + void* gamma2, + void* beta2, + void* quant_scale2, + void* quant_offset2, + void* gamma3, + void* sin1, + void* cos1, + void* sin2, + void* cos2, + void* keycache, + void* slot_mapping, + void* wuq, + void* bias2, + void* wuk, + void* descale1, + void* descale2, + void* ctkv_scale, + void* 
qnope_scale, + void* q, + void* keycache_out, + void* q2, + void* keycache_out2, + void* workspace, + void* tiling, + const uint32_t block_dim + ); } diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 5dd6988a9d..74614e583b 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -23,6 +23,7 @@ #include "acl/acl.h" #include "ops.h" #include "utils.h" +#include "mla_preprocess/op_host/mla_preprocess.h" namespace vllm_ascend { @@ -106,6 +107,83 @@ std::tuple rotary_embedding(at::Tensor &positions, at::T return {query_dst, key_dst}; } +std::tuple mla_preprocess( + const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv, + const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq, + const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, + const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, + const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0, + const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1, + const c10::optional &ctkv_scale, const c10::optional &q_nope_scale, + c10::optional cache_mode, c10::optional quant_mode, at::Tensor &q_out0, + at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1) +{ + at::Tensor CtkvScale = + ctkv_scale.has_value() + ? ctkv_scale.value() + : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); + at::Tensor QnopeScale = + q_nope_scale.has_value() + ? q_nope_scale.value() + : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); + + auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling( + hiddenState, + wuk, + cache_mode, + quant_mode + ); + + void *hidden_state_ptr = hiddenState.data_ptr(); + void *gamma0_ptr = gamma0.data_ptr(); + void *beta0_ptr = beta0.data_ptr(); + void *quant_scale0_ptr = quant_scale0.data_ptr(); + void *quant_offset0_ptr = quant_offset0.data_ptr(); + void *wdqkv_ptr = wdqkv.data_ptr(); + void *bias0_ptr = bias0.data_ptr(); + void *gamma1_ptr = gamma1.data_ptr(); + void *beta1_ptr = beta1.data_ptr(); + void *quant_scale1_ptr = quant_scale1.data_ptr(); + void *quant_offset1_ptr = quant_offset1.data_ptr(); + void *gamma2_ptr = gamma2.data_ptr(); + void *sin_ptr = sin.data_ptr(); + void *cos_ptr = cos.data_ptr(); + void *kv_cache_ptr = kv_cache.data_ptr(); + void *slotmapping_ptr = slotmapping.data_ptr(); + void *wuq_ptr = wuq.data_ptr(); + void *bias1_ptr = bias1.data_ptr(); + void *wuk_ptr = wuk.data_ptr(); + void *descale0_ptr = descale0.data_ptr(); + void *descale1_ptr = descale1.data_ptr(); + void *ctkv_scale_ptr = CtkvScale.data_ptr(); + void *qnope_scale_ptr = QnopeScale.data_ptr(); + void *q_out0_ptr = q_out0.data_ptr(); + void *kv_cache_out0_ptr = kv_cache_out0.data_ptr(); + void *q_out1_ptr = q_out1.data_ptr(); + void *kv_cache_out1_ptr = kv_cache_out1.data_ptr(); + void *workspace_ptr = workspace_tensor.data_ptr(); + void *tiling_ptr = tiling.data_ptr(); + + aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + at_npu::native::OpCommand cmd; + cmd.Name("mla_preprocess"); + + cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, + gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, + kv_cache_ptr, 
slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, + qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr, + tiling_ptr, block_dim]() -> int { + mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, + gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr, + kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, + qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr, + tiling_ptr, block_dim); + return 0; + }); + cmd.Run(); + return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1); +} + std::tuple get_masked_input_and_mask( at::Tensor &input, const int64_t org_vocab_start_index, @@ -422,4 +500,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y," " int slice_offset, int slice_size) -> Tensor"); ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand); + + ops.def( + "mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv," + " Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1," + " Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache," + " Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0," + " Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1," + " Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode," + " str? quant_mode, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1," + " Tensor! kv_cache_out1) -> (Tensor q_out0, Tensor kv_cache_out0," + " Tensor q_out1, Tensor kv_cache_out1)" + ); + ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 4101ee71e0..bf7ed01904 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -81,6 +81,41 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ return y_out; } +std::tuple mla_preprocess( + const at::Tensor &hiddenState, + const at::Tensor &gamma0, + const at::Tensor &beta0, + const at::Tensor &wdqkv, + const at::Tensor &descale0, + const at::Tensor &gamma1, + const at::Tensor &beta1, + const at::Tensor &wuq, + const at::Tensor &descale1, + const at::Tensor &gamma2, + const at::Tensor &cos, + const at::Tensor &sin, + const at::Tensor &wuk, + const at::Tensor &kv_cache, + const at::Tensor &kv_cache_rope, + const at::Tensor &slotmapping, + const at::Tensor &quant_scale0, + const at::Tensor &quant_offset0, + const at::Tensor &bias0, + const at::Tensor &quant_scale1, + const at::Tensor &quant_offset1, + const at::Tensor &bias1, + const c10::optional &ctkv_scale, + const c10::optional &q_nope_scale, + c10::optional cache_mode, + c10::optional quant_mode, + at::Tensor &q_out0, + at::Tensor &kv_cache_out0, + at::Tensor &q_out1, + at::Tensor &kv_cache_out1) +{ + return {q_out0, kv_cache_out0, q_out1, kv_cache_out1}; +} + } // namespace meta } // namespace vllm_ascend @@ -97,6 +132,7 @@ namespace { ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta); // Sgmv expand ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); - + // MLA preprocess + ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess); } }
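
Usage sketch (illustrative, not part of the diff): given the schema registered above through
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops), the op should surface in Python as
torch.ops._C_ascend.mla_preprocess. The snippet below only illustrates the argument order defined
by that schema; every tensor variable in it is assumed to be a pre-allocated NPU tensor with
whatever shape and dtype the kernel expects, the optional arguments are left as None purely for
brevity, and none of these names, shapes, or dtypes are taken from this patch.

    import torch
    # Assumes the compiled vllm-ascend extension has already been imported, so the
    # _C_ascend library and its "mla_preprocess" schema are registered with torch.ops.
    q_out0, kv_cache_out0, q_out1, kv_cache_out1 = torch.ops._C_ascend.mla_preprocess(
        hidden_state, gamma0, beta0, wdqkv, descale0,
        gamma1, beta1, wuq, descale1, gamma2,
        cos, sin, wuk, kv_cache, kv_cache_rope, slot_mapping,
        quant_scale0, quant_offset0, bias0,
        quant_scale1, quant_offset1, bias1,
        None,  # ctkv_scale   (Tensor?, optional)
        None,  # q_nope_scale (Tensor?, optional)
        None,  # cache_mode   (str?, optional)
        None,  # quant_mode   (str?, optional)
        q_out0, kv_cache_out0, q_out1, kv_cache_out1)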