diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 631ea4e3c2b..4a5591b0638 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;" + CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;lightning_attention_decode;lightning_attention_prefill" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -68,6 +68,8 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "causal_conv1d" "moe_grouped_matmul" "lightning_indexer_quant" + "lightning_attention_decode" + "lightning_attention_prefill" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/lightning_attention_decode/lightning_attention_decode_torch_adpt.h b/csrc/lightning_attention_decode/lightning_attention_decode_torch_adpt.h new file mode 100644 index 00000000000..a1fc0a4a279 --- /dev/null +++ b/csrc/lightning_attention_decode/lightning_attention_decode_torch_adpt.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIGHTNING_ATTENTION_DECODE_TORCH_ADPT_H +#define LIGHTNING_ATTENTION_DECODE_TORCH_ADPT_H + +namespace vllm_ascend { + at::Tensor npu_lightning_attention_decode( + const at::Tensor &query, + const at::Tensor &key, + const at::Tensor &value, + const at::Tensor &kv_caches_ref, + const at::Tensor &slope_rate, + const at::Tensor &slot_ids) + { + auto output_size_0 = {query.size(0), query.size(1) * query.size(3)}; + auto output_dtype_0 = query.scalar_type(); + at::Tensor attention_out = at::empty(output_size_0, query.options().dtype(output_dtype_0)); + EXEC_NPU_CMD( + aclnnLightningAttentionDecode, + query, + key, + value, + slope_rate, + kv_caches_ref, + slot_ids, + "BNSD", + attention_out); + return attention_out; + } +} + +#endif diff --git a/csrc/lightning_attention_decode/op_host/CMakeLists.txt b/csrc/lightning_attention_decode/op_host/CMakeLists.txt new file mode 100644 index 00000000000..19921dc6666 --- /dev/null +++ b/csrc/lightning_attention_decode/op_host/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

# Kernel compile flags for the LightningAttentionDecode op.
add_ops_compile_options(
    OP_NAME LightningAttentionDecode
    OPTIONS --cce-auto-sync -Wno-deprecated-declarations -Werror
)

# Host-side sources, one CANN component target each.
# op definition (registration)
target_sources(op_host_aclnnInner PRIVATE
    lightning_attention_decode_def.cpp
)
# public aclnn two-phase wrappers
target_sources(opapi PRIVATE
    aclnn_lightning_attention.cpp
)
# tiling implementation
target_sources(optiling PRIVATE
    lightning_attention_decode_tiling.cpp
)
# shape / dtype inference
target_sources(opsproto PRIVATE
    lightning_attention_decode_proto.cpp
)

# The non-open build additionally compiles the tiling into opmaster_ct.
if (NOT BUILD_OPEN_PROJECT)
    target_sources(opmaster_ct PRIVATE
        lightning_attention_decode_tiling.cpp
    )
endif ()

target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)

# Ship the public aclnn header with the package.
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_lightning_attention_decode.h
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
+ */ + +#include "graph/types.h" +#include "aclnn_lightning_attention_decode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +extern aclnnStatus aclnnInnerLightningAttentionDecodeGetWorkspaceSize( + const aclTensor *query, + const aclTensor *key, + const aclTensor *value, + const aclTensor *slopeRate, + aclTensor *kvCachesRef, + const aclTensor *slotIds, + char *inputLayoutOptional, + const aclTensor *attentionOut, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +extern aclnnStatus aclnnInnerLightningAttentionDecode( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + +aclnnStatus aclnnLightningAttentionDecodeGetWorkspaceSize( + const aclTensor *query, + const aclTensor *key, + const aclTensor *value, + const aclTensor *slopeRate, + aclTensor *kvCachesRef, + const aclTensor *slotIds, + char *inputLayoutOptional, + const aclTensor *attentionOut, + uint64_t *workspaceSize, + aclOpExecutor **executor) +{ + aclnnStatus ret = aclnnInnerLightningAttentionDecodeGetWorkspaceSize( + query, key, value, slopeRate, kvCachesRef, slotIds, + inputLayoutOptional, attentionOut, workspaceSize, executor); + return ret; +} + +aclnnStatus aclnnLightningAttentionDecode( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream) +{ + aclnnStatus ret = aclnnInnerLightningAttentionDecode(workspace, workspaceSize, executor, stream); + return ret; +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/lightning_attention_decode/op_host/aclnn_lightning_attention_decode.h b/csrc/lightning_attention_decode/op_host/aclnn_lightning_attention_decode.h new file mode 100644 index 00000000000..a9321a60697 --- /dev/null +++ b/csrc/lightning_attention_decode/op_host/aclnn_lightning_attention_decode.h @@ -0,0 +1,65 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef ACLNN_LIGHTNING_ATTENTION_DECODE_H_
#define ACLNN_LIGHTNING_ATTENTION_DECODE_H_

#include "aclnn/acl_meta.h"

#ifdef __cplusplus
extern "C" {
#endif

/* function: aclnnLightningAttentionDecodeGetWorkspaceSize
 * First phase of the two-phase aclnn call: computes the workspace size and
 * builds the executor. No device computation happens here.
 * parameters :
 * query : required
 * key : required
 * value : required
 * slopeRate : required
 * kvCachesRef : required (in/out reference -- the op updates the cache in place)
 * slotIds : required
 * inputLayoutOptional : optional
 * attentionOut : required
 * workspaceSize : size of workspace(output).
 * executor : executor context(output).
 */
__attribute__((visibility("default")))
aclnnStatus aclnnLightningAttentionDecodeGetWorkspaceSize(
    const aclTensor *query,
    const aclTensor *key,
    const aclTensor *value,
    const aclTensor *slopeRate,
    aclTensor *kvCachesRef,
    const aclTensor *slotIds,
    char *inputLayoutOptional,
    const aclTensor *attentionOut,
    uint64_t *workspaceSize,
    aclOpExecutor **executor);

/* function: aclnnLightningAttentionDecode
 * Second phase: launches the op on the given stream using the workspace and
 * executor produced by the first phase.
 * parameters :
 * workspace : workspace memory addr(input).
 * workspaceSize : size of workspace(input).
 * executor : executor context(input).
 * stream : acl stream.
 */
__attribute__((visibility("default")))
aclnnStatus aclnnLightningAttentionDecode(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "register/op_def_registry.h"

namespace ops {

// Op registration for LightningAttentionDecode.
//
// Every DataType/Format list below has three columns because the op is
// registered for three dtype configurations (fp16 / bf16 / fp32); column i of
// each input/output forms one consistent overload. slot_ids is always int32.
class LightningAttentionDecode : public OpDef {
public:
    explicit LightningAttentionDecode(const char* name) : OpDef(name)
    {
        // Decode-step activations.
        this->Input("query")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("key")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("value")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-head decay slope.
        this->Input("slope_rate")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // KV cache, also declared as an output below.
        // NOTE(review): input and output deliberately share the name
        // "kv_caches" -- presumably the framework's ref-output convention for
        // in-place update; confirm against the op registration docs.
        this->Input("kv_caches")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Cache-slot index per batch element.
        this->Input("slot_ids")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("attention")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("kv_caches")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Layout of query/key/value; the torch adapter always passes "BNSD".
        this->Attr("input_layout").AttrType(OPTIONAL).String("BNSD");

        // Supported SoCs: Ascend 910B (A2) and 910_93 (A3) series.
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
    }
};

OP_ADD(LightningAttentionDecode);
}
+ */ + +#include "register/op_def_registry.h" + +namespace ops { + +static constexpr size_t INDEX_IN_Q = 0; +static constexpr size_t INDEX_IN_K = 1; +static constexpr size_t INDEX_IN_V = 2; +static constexpr size_t INDEX_IN_SLP_RATE = 3; +static constexpr size_t INDEX_IN_KV_HIS = 4; +static constexpr size_t INDEX_IN_SLT_IDS = 5; +static constexpr size_t DIM_2 = 2; +static constexpr size_t DIM_3 = 3; +static constexpr size_t INDEX_OUT_ATTN = 0; +static constexpr size_t INDEX_OUT_KV_CACHES = 1; + +static ge::graphStatus InferShapeLightningAttentionDecode(gert::InferShapeContext* context) +{ + const gert::Shape* q_shape = context->GetInputShape(INDEX_IN_Q); + gert::Shape* attn_out_shape = context->GetOutputShape(INDEX_OUT_ATTN); + gert::Shape* kv_caches_shape = context->GetOutputShape(INDEX_OUT_KV_CACHES); + *attn_out_shape = *q_shape; + + kv_caches_shape->SetDimNum(q_shape->GetDimNum()); + kv_caches_shape->SetDim(0, q_shape->GetDim(0)); + kv_caches_shape->SetDim(1, q_shape->GetDim(1)); + kv_caches_shape->SetDim(DIM_2, q_shape->GetDim(DIM_3)); + kv_caches_shape->SetDim(DIM_3, q_shape->GetDim(DIM_3)); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeLightningAttentionDecode(gert::InferDataTypeContext *context) +{ + const auto inputDataType = context->GetInputDataType(INDEX_IN_Q); + context->SetOutputDataType(INDEX_OUT_ATTN, inputDataType); + context->SetOutputDataType(INDEX_OUT_KV_CACHES, inputDataType); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(LightningAttentionDecode) + .InferShape(InferShapeLightningAttentionDecode) + .InferDataType(InferDataTypeLightningAttentionDecode); + +} diff --git a/csrc/lightning_attention_decode/op_host/lightning_attention_decode_tiling.cpp b/csrc/lightning_attention_decode/op_host/lightning_attention_decode_tiling.cpp new file mode 100644 index 00000000000..d5b6d432e8c --- /dev/null +++ b/csrc/lightning_attention_decode/op_host/lightning_attention_decode_tiling.cpp @@ -0,0 +1,160 @@ +/** + * Copyright (c) 
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "lightning_attention_decode_tiling.h"
#include "register/op_impl_registry.h"

namespace optiling {

// Upper bound for the matmul base block in the M dimension (used in SetFixSplit).
static constexpr uint32_t MAX_BASE_M = 128;
// Axis index of the head dimension in the BNSD query shape.
static constexpr size_t DIM_3 = 3;

// Only tiling strategy registered for this op: always applicable.
bool LightningAttentionDecodeTiling::IsCapable()
{
    return true;
}

// Cache core counts and on-chip memory sizes from the platform descriptor.
ge::graphStatus LightningAttentionDecodeTiling::GetPlatformInfo()
{
    aicNum_ = ascendcPlatform_->GetCoreNumAic();
    aivNum_ = ascendcPlatform_->GetCoreNumAiv();
    actualUsedAivNum_ = aivNum_;  // may be clipped to taskNum_ in DoOpTiling()
    ascendcPlatform_->GetCoreMemSize(platform_ascendc::CoreMemType::UB, aicoreParams_.ubSize);
    ascendcPlatform_->GetCoreMemSize(platform_ascendc::CoreMemType::L1, aicoreParams_.l1Size);
    ascendcPlatform_->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, aicoreParams_.l0cSize);
    return ge::GRAPH_SUCCESS;
}

// Shape/attr stage: only the input dtype needs validation here.
ge::graphStatus LightningAttentionDecodeTiling::GetShapeAttrsInfo()
{
    if (!AnalyzeDType()) {
        return ge::GRAPH_FAILED;  // unsupported query dtype
    }
    return ge::GRAPH_SUCCESS;
}

// Fill the base params from the query (input 0, BNSD) and kv_caches (input 4)
// storage shapes. One task == one (batch, head) pair.
ge::graphStatus LightningAttentionDecodeTiling::DoOpTiling()
{
    auto qShape = context_->GetInputShape(0)->GetStorageShape();
    auto kvCacheShape = context_->GetInputShape(4)->GetStorageShape();
    // set base params
    tilingData_.laBaseParams.set_batchSize(qShape.GetDim(0));
    // Cache slot count -- tracked separately from the query batch.
    tilingData_.laBaseParams.set_kvCacheBatchSize(kvCacheShape.GetDim(0));
    tilingData_.laBaseParams.set_headNum(qShape.GetDim(1));
    headDimBlock_ = qShape.GetDim(DIM_3);
    tilingData_.laBaseParams.set_headDim(headDimBlock_);

    taskNum_ = tilingData_.laBaseParams.get_batchSize() * tilingData_.laBaseParams.get_headNum();
    if (taskNum_ < actualUsedAivNum_) {
        actualUsedAivNum_ = taskNum_;  // never launch more vector cores than tasks
    }
    tilingData_.laBaseParams.set_actualUsedAivNum(actualUsedAivNum_);
    tilingData_.laBaseParams.set_taskNum(taskNum_);
    return ge::GRAPH_SUCCESS;
}

// Library tiling: only the fp32 path drives the cube unit and needs a matmul tiling.
ge::graphStatus LightningAttentionDecodeTiling::DoLibApiTiling()
{
    if (UseMatmul() && !SetMatmulTiling()) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}

// Single kernel variant; dtype dispatch happens at compile time in the kernel.
uint64_t LightningAttentionDecodeTiling::GetTilingKey() const
{
    return 0;
}

// No op-specific workspace; the lib API workspace is added in PostTiling().
ge::graphStatus LightningAttentionDecodeTiling::GetWorkspaceSize()
{
    workspaceSize_ = 0;
    return ge::GRAPH_SUCCESS;
}

// Publish block dim, workspace size and the serialized tiling blob.
ge::graphStatus LightningAttentionDecodeTiling::PostTiling()
{
    if (UseMatmul()) {
        // Mixed cube/vector launch: translate the vector-core count into a
        // task-scheduler block dim.
        auto blockDim = CalcTschBlockDim(actualUsedAivNum_, aicNum_, aivNum_);
        context_->SetBlockDim(blockDim);
    } else {
        context_->SetBlockDim(actualUsedAivNum_);
    }
    size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
    currentWorkspace[0] = workspaceSize_ + ascendcPlatform_->GetLibApiWorkSpaceSize();
    tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
        context_->GetRawTilingData()->GetCapacity());
    context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize());
    return ge::GRAPH_SUCCESS;
}

// The cube (matmul) path is used only for fp32; fp16/bf16 take a pure vector path.
bool LightningAttentionDecodeTiling::UseMatmul() const {
    return mm1InDType_ == matmul_tiling::DataType::DT_FLOAT;
}

// Map the query's ge dtype onto the matmul tiling dtypes; rejects anything
// other than fp16 / bf16 / fp32.
bool LightningAttentionDecodeTiling::AnalyzeDType()
{
    inputDType_ = context_->GetInputDesc(0)->GetDataType();
    switch (inputDType_) {
        case ge::DT_FLOAT16:
            mm1InDType_ = matmul_tiling::DataType::DT_FLOAT16;
            mm1OutDType_ = matmul_tiling::DataType::DT_FLOAT16;
            break;
        case ge::DT_BF16:
            mm1InDType_ = matmul_tiling::DataType::DT_BF16;
            mm1OutDType_ = matmul_tiling::DataType::DT_BF16;
            break;
        case ge::DT_FLOAT:
            mm1InDType_ = matmul_tiling::DataType::DT_FLOAT;
            mm1OutDType_ = matmul_tiling::DataType::DT_FLOAT;
            break;
        default:
            return false;
    }
    return true;
}

// Build the tiling for O = Q x KV: a (1 x D) x (D x D) matmul with A read
// from GM and B produced in UB (VECCALC).
// NOTE(review): std::min is used below; <algorithm> is presumably pulled in
// transitively via tiling_api.h -- confirm, or include it explicitly.
bool LightningAttentionDecodeTiling::SetMatmulTiling()
{
    matmul_tiling::MatmulApiTiling mm1(*ascendcPlatform_);
    mm1.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm1InDType_, false);
    mm1.SetBType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, mm1InDType_, false);
    mm1.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm1OutDType_);
    mm1.SetShape(1, headDimBlock_, headDimBlock_);  // M=1, N=D, K=D
    mm1.SetOrgShape(1, headDimBlock_, headDimBlock_, headDimBlock_);
    mm1.SetBias(false);
    if (mm1.SetBufferSpace(aicoreParams_.l1Size, aicoreParams_.l0cSize) != 0) {
        return false;
    }
    if (mm1.SetFixSplit(-1, std::min(headDimBlock_, MAX_BASE_M)) != 0) {
        return false;
    }
    if (mm1.GetTiling(tilingData_.mm1TilingData) != 0) {
        return false;
    }
    return true;
}

// Registered tiling entry point.
ASCENDC_EXTERN_C ge::graphStatus TilingLightningAttentionDecode(gert::TilingContext* context)
{
    LightningAttentionDecodeTiling tiling(context);
    return tiling.DoTiling();
}

// No compile-info parsing is required; `context` is intentionally unused.
ASCENDC_EXTERN_C ge::graphStatus TilingPrepareForLightningAttentionDecode(gert::TilingParseContext *context)
{
    return ge::GRAPH_SUCCESS;
}

IMPL_OP_OPTILING(LightningAttentionDecode)
    .Tiling(TilingLightningAttentionDecode)
    .TilingParse(TilingPrepareForLightningAttentionDecode);

}
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef LIGHTNING_ATTENTION_DECODE_TILING_H
#define LIGHTNING_ATTENTION_DECODE_TILING_H

#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "tiling/tiling_base.h"

namespace optiling {

// Scalar parameters handed from host tiling to the device kernel.
BEGIN_TILING_DATA_DEF(LightningAttentionDecodeBaseParams)
    TILING_DATA_FIELD_DEF(uint32_t, batchSize);         // query batch (B)
    TILING_DATA_FIELD_DEF(uint32_t, kvCacheBatchSize);  // number of cache slots; may differ from batchSize
    TILING_DATA_FIELD_DEF(uint32_t, headNum);           // attention head count (N)
    TILING_DATA_FIELD_DEF(uint32_t, headDim);           // per-head dimension (D)
    TILING_DATA_FIELD_DEF(uint32_t, actualUsedAivNum);  // vector cores actually launched (<= taskNum)
    TILING_DATA_FIELD_DEF(uint32_t, taskNum);           // B * N work items
END_TILING_DATA_DEF;

REGISTER_TILING_DATA_CLASS(LightningAttentionDecodeBaseParamsOp, LightningAttentionDecodeBaseParams)

// Full tiling blob: matmul tiling (used by the fp32 path) plus base params.
BEGIN_TILING_DATA_DEF(LightningAttentionDecodeTilingData)
    TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm1TilingData);
    TILING_DATA_FIELD_DEF_STRUCT(LightningAttentionDecodeBaseParams, laBaseParams);
END_TILING_DATA_DEF;

REGISTER_TILING_DATA_CLASS(LightningAttentionDecode, LightningAttentionDecodeTilingData)

// No compile-time (TilingParse) state is needed for this op.
struct LightningAttentionDecodeCompileInfo {};

// Tiling strategy for LightningAttentionDecode; the protected overrides are
// the stages invoked in order by TilingBaseClass::DoTiling().
class LightningAttentionDecodeTiling : public TilingBaseClass {
public:
    explicit LightningAttentionDecodeTiling(gert::TilingContext *context)
        : TilingBaseClass(context)
    {
        ascendcPlatform_.reset(new platform_ascendc::PlatformAscendC(context->GetPlatformInfo()));
    }
protected:
    // Always true: the only strategy registered for this op.
    bool IsCapable() override;

    // Reads core counts and UB/L1/L0C sizes.
    ge::graphStatus GetPlatformInfo() override;

    // Validates the input dtype (fp16 / bf16 / fp32).
    ge::graphStatus GetShapeAttrsInfo() override;

    // Fills base params from the query / kv-cache shapes.
    ge::graphStatus DoOpTiling() override;

    // Builds the matmul tiling (fp32 path only).
    ge::graphStatus DoLibApiTiling() override;

    // Single kernel variant -> always 0.
    uint64_t GetTilingKey() const override;

    // No op-specific workspace beyond the lib API workspace.
    ge::graphStatus GetWorkspaceSize() override;

    // Writes block dim, workspace and serialized tiling into the context.
    ge::graphStatus PostTiling() override;

private:
    bool UseMatmul() const;
    bool AnalyzeDType();
    bool SetMatmulTiling();

private:
    LightningAttentionDecodeTilingData tilingData_;

    ge::DataType inputDType_;    // query dtype, set in AnalyzeDType()
    uint32_t aicNum_;            // cube core count
    uint32_t aivNum_;            // vector core count
    uint32_t actualUsedAivNum_;  // min(aivNum_, taskNum_)
    uint32_t taskNum_;           // batchSize * headNum
    uint32_t headDimBlock_;      // head dimension (D)

    matmul_tiling::DataType mm1InDType_ = matmul_tiling::DataType::DT_FLOAT;
    matmul_tiling::DataType mm1OutDType_ = matmul_tiling::DataType::DT_FLOAT;
};

}

#endif // LIGHTNING_ATTENTION_DECODE_TILING_H
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */


#include "kernel_operator.h"
#include "lightning_attention_decode.h"

using namespace LightningAttention;

// Materialize the tiling blob passed via GM into a local struct and expose a
// const pointer to it under the name `tilingData`.
#define COPY_TILING_DATA(tiling) \
    GET_TILING_DATA(tilingDataIn, tiling); \
    const LightningAttentionDecodeTilingData *__restrict tilingData = &tilingDataIn

// Device kernel entry point. The query dtype is fixed at kernel compile time
// (ORIG_DTYPE_QUERY); fp16/bf16 run a pure vector path, while fp32
// additionally registers a matmul object driven by the mm1 tiling.
//
// NOTE(review): the three dtype branches read identically below because the
// template argument lists of `LightningAttentionDecode` appear to have been
// lost from this copy (angle-bracketed text is missing throughout this file
// set, e.g. `static_cast(...)` elsewhere); restore them from the original
// source before relying on this text.
extern "C" __global__ __aicore__ void lightning_attention_decode(
    GM_ADDR query, GM_ADDR key, GM_ADDR value, GM_ADDR slope_rate, GM_ADDR kv_caches_ref_in, GM_ADDR slot_ids,
    GM_ADDR attention_out, GM_ADDR kv_caches_ref_out, GM_ADDR workspace, GM_ADDR tiling)
{
    AscendC::TPipe pipe;
    COPY_TILING_DATA(tiling);
#if (ORIG_DTYPE_QUERY == DT_FLOAT16)
    // half path: vector-only implementation.
    LightningAttentionDecode op;
    op.Init(query, key, value, slope_rate, kv_caches_ref_in, slot_ids, attention_out, kv_caches_ref_out,
        workspace, tilingData, &pipe);
    op.Process();
#elif (ORIG_DTYPE_QUERY == DT_BF16)
    // bf16 path: vector-only implementation.
    LightningAttentionDecode op;
    op.Init(query, key, value, slope_rate, kv_caches_ref_in, slot_ids, attention_out, kv_caches_ref_out,
        workspace, tilingData, &pipe);
    op.Process();
#elif (ORIG_DTYPE_QUERY == DT_FLOAT)
    // fp32 path: uses the cube unit; bind the matmul object to its tiling
    // before Init so Process can iterate it.
    LightningAttentionDecode op;
    const TCubeTiling *__restrict mm1tiling = &(tilingData->mm1TilingData);
    REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), op.mm1, mm1tiling);
    op.Init(query, key, value, slope_rate, kv_caches_ref_in, slot_ids, attention_out, kv_caches_ref_out,
        workspace, tilingData, &pipe);
    op.Process();
#endif
}
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + +#ifndef LIGHTNING_ATTENTION_DECODE_H +#define LIGHTNING_ATTENTION_DECODE_H + +#include "kernel_operator.h" +#include "lib/matmul_intf.h" + +using namespace matmul; + +namespace LightningAttention { + +template +class LightningAttentionDecode { +public: + __aicore__ inline LightningAttentionDecode() {} + __aicore__ inline void Init(GM_ADDR query, GM_ADDR key, GM_ADDR value, GM_ADDR slope_rate, + GM_ADDR kv_history, GM_ADDR slot_ids, GM_ADDR attention_out, GM_ADDR kv_cache_out, + GM_ADDR workspace, const LightningAttentionDecodeTilingData *__restrict tiling, + AscendC::TPipe *pipe); + __aicore__ inline void Process(); + +public: + // define matmul object for matmul(Q, KV) + using a1Type = MatmulType; + using b1Type = MatmulType; + using c1Type = MatmulType; + using bias1Type = MatmulType; + Matmul mm1; + +private: + __aicore__ inline void GenerateDecay(); + __aicore__ inline void ComputeAttention(uint32_t offset); + __aicore__ inline void UpdateKVCache(uint32_t kvCacheOffset, uint32_t offset, uint32_t headIdx); + __aicore__ inline void SaveKVCache(uint32_t kvCacheOffset); + __aicore__ inline void WaitKVCacheSaved(); + +private: + AscendC::GlobalTensor queryGM_; + AscendC::GlobalTensor keyGM_; + AscendC::GlobalTensor valueGM_; + AscendC::GlobalTensor slopeRateGM_; + AscendC::GlobalTensor slotIdsGM_; + AscendC::GlobalTensor attentionOutGM_; + AscendC::GlobalTensor kvCacheHistoryGM_; + AscendC::GlobalTensor kvCacheOutGM_; + + uint32_t currentCoreId_; + const LightningAttentionDecodeTilingData *__restrict tiling_; + uint32_t batchSize_; + uint32_t kvCacheBatchSize_; + uint32_t headNum_; + uint32_t 
headNumPad_; + uint32_t headDim_; + uint32_t actualUsedAivNum_; + uint32_t taskNum_; + uint32_t eleCountPerKVCache_; + + AscendC::TBuf kvCacheBuf_; + AscendC::TBuf decayBuf_; + AscendC::TBuf kBuf_; + AscendC::TBuf vBuf_; + AscendC::TQue qInQue_; + + AscendC::TQue attentionOutQue_; + AscendC::TBuf broadCastBuf_; + AscendC::TBuf kvFp32Buf_; + + AscendC::TBuf kvCacheFp32Buf_; + AscendC::TBuf decayFp32Buf_; + AscendC::TBuf qFp32Buf_; + AscendC::TBuf kFp32Buf_; + AscendC::TBuf vFp32Buf_; + AscendC::TBuf attentionFp32Buf_; +}; + +template +__aicore__ inline void LightningAttentionDecode::Init( + GM_ADDR query, GM_ADDR key, GM_ADDR value, GM_ADDR slope_rate, GM_ADDR kv_history, GM_ADDR slot_ids, + GM_ADDR attention_out, GM_ADDR kv_cache_out, GM_ADDR workspace, + const LightningAttentionDecodeTilingData *__restrict tiling, AscendC::TPipe *pipe) +{ + currentCoreId_ = GetBlockIdx(); + tiling_ = tiling; + batchSize_ = tiling->laBaseParams.batchSize; + kvCacheBatchSize_ = tiling->laBaseParams.kvCacheBatchSize; + headNum_ = tiling->laBaseParams.headNum; + headNumPad_ = headNum_ < 16 ? 
16 : headNum_; + headDim_ = tiling->laBaseParams.headDim; + actualUsedAivNum_ = tiling->laBaseParams.actualUsedAivNum; + taskNum_ = tiling->laBaseParams.taskNum; + eleCountPerKVCache_ = headDim_ * headDim_; + + queryGM_.SetGlobalBuffer((__gm__ T*)query); + keyGM_.SetGlobalBuffer((__gm__ T*)key); + valueGM_.SetGlobalBuffer((__gm__ T*)value); + slopeRateGM_.SetGlobalBuffer((__gm__ T*)slope_rate); + slotIdsGM_.SetGlobalBuffer((__gm__ int32_t*)slot_ids); + attentionOutGM_.SetGlobalBuffer((__gm__ T*)attention_out); + + kvCacheHistoryGM_.SetGlobalBuffer((__gm__ T*)kv_history); + kvCacheOutGM_.SetGlobalBuffer((__gm__ T*)kv_cache_out); + + auto maxBufSize = 128 * 128; + + pipe->InitBuffer(kvCacheBuf_, sizeof(T) * maxBufSize); // 32k for half, 64k for fp32 + pipe->InitBuffer(decayBuf_, sizeof(T) * headNumPad_); // maximum headNum is 64, 0.125k + pipe->InitBuffer(kBuf_, sizeof(T) * headDim_); // 0.25k + pipe->InitBuffer(vBuf_, sizeof(T) * headDim_); // 0.25k + pipe->InitBuffer(qInQue_, 1, sizeof(T) * headDim_); // 0.25k + pipe->InitBuffer(kvFp32Buf_, sizeof(float) * maxBufSize); // 64k + pipe->InitBuffer(broadCastBuf_, 4096); // reserved for broadcast, 4k + pipe->InitBuffer(attentionOutQue_, 1, sizeof(T) * headDim_); // 0.25k + + if constexpr (!IsSameType::value) { + pipe->InitBuffer(kvCacheFp32Buf_, sizeof(float) * maxBufSize); // 64k + pipe->InitBuffer(decayFp32Buf_, sizeof(float) * headNumPad_); // 0.25k + pipe->InitBuffer(qFp32Buf_, sizeof(float) * headDim_); // 0.5k + pipe->InitBuffer(kFp32Buf_, sizeof(float) * headDim_); // 0.5k + pipe->InitBuffer(vFp32Buf_, sizeof(float) * headDim_); // 0.5k + pipe->InitBuffer(attentionFp32Buf_, sizeof(float) * headDim_); // 0.5k + } +} + +template +__aicore__ inline void LightningAttentionDecode::Process() +{ + uint16_t absoluteHeadIdx = currentCoreId_; + uint32_t offset = absoluteHeadIdx * headDim_; + uint32_t offsetStep = actualUsedAivNum_ * headDim_; + uint32_t kvCacheOffsetPerBatch = eleCountPerKVCache_ * headNum_; + + 
// Process(): each vector core strides over the batch*headNum task space (absoluteHeadIdx starts at
// the core id and advances by actualUsedAivNum_), computing one head's decode step per iteration:
// UpdateKVCache -> ComputeAttention -> SaveKVCache, with WaitKVCacheSaved() serialising the
// MTE3 write-out of the previous iteration against the next iteration's MTE2 load.
// NOTE(review): kvCacheSlotId is declared uint32_t in the for-init below, so the guard
// "kvCacheSlotId < 0" is tautologically false. A negative int32 read from slotIdsGM_ wraps to a
// large unsigned value and is in practice still rejected by ">= kvCacheBatchSize_", but the slot id
// should be declared int32_t so the intended sign check is meaningful — TODO confirm and fix in the
// original patch. Also note the "continue" still runs the loop increment, so "offset" stays in
// lock-step with absoluteHeadIdx when a slot is skipped.
GenerateDecay(); + + bool isFirstLoop = true; + for (uint32_t relativeHeadIdx, batchId, kvCacheSlotId, kvCacheOffset; absoluteHeadIdx < taskNum_; + absoluteHeadIdx += actualUsedAivNum_, offset += offsetStep) { + batchId = absoluteHeadIdx / headNum_; + kvCacheSlotId = slotIdsGM_.GetValue(batchId); + if (kvCacheSlotId < 0 || kvCacheSlotId >= kvCacheBatchSize_) { + continue; + } + relativeHeadIdx = absoluteHeadIdx % headNum_; + kvCacheOffset = kvCacheSlotId * kvCacheOffsetPerBatch + relativeHeadIdx * eleCountPerKVCache_; + if (isFirstLoop) { + isFirstLoop = false; + } else { + WaitKVCacheSaved(); + } + UpdateKVCache(kvCacheOffset, offset, relativeHeadIdx); + ComputeAttention(offset); + SaveKVCache(kvCacheOffset); + } +} + +template +__aicore__ inline void LightningAttentionDecode::GenerateDecay() +{ + auto decayTTensor = decayBuf_.Get(); + // Copy in + AscendC::DataCopy(decayTTensor, slopeRateGM_, headNumPad_); + int32_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + // Compute + if constexpr (IsSameType::value) { + AscendC::Muls(decayTTensor, decayTTensor, (float)-1.0, headNumPad_); + AscendC::PipeBarrier(); + AscendC::Exp(decayTTensor, decayTTensor, headNumPad_); + } else { + auto decayFp32Tensor = decayFp32Buf_.Get(); + AscendC::Cast(decayFp32Tensor, decayTTensor, RoundMode::CAST_NONE, headNumPad_); + AscendC::PipeBarrier(); + AscendC::Muls(decayFp32Tensor, decayFp32Tensor, (float)-1.0, headNumPad_); + AscendC::PipeBarrier(); + AscendC::Exp(decayFp32Tensor, decayFp32Tensor, headNumPad_); + } +} + + +template +__aicore__ inline void LightningAttentionDecode::ComputeAttention(uint32_t offset) +{ + // calculate O = matmul(Q, KV) + if constexpr (IsSameType::value) { + auto kvCacheTensor = kvCacheBuf_.Get(); + mm1.SetTensorA(queryGM_[offset]); + mm1.SetTensorB(kvCacheTensor); + mm1.template IterateAll(attentionOutGM_[offset]); + mm1.End(); + } else { + auto qTensor = 
qInQue_.AllocTensor(); + auto qBroadCastTensor = kvFp32Buf_.Get(); + auto broadCastTensor = broadCastBuf_.Get(); + const uint32_t dstShape[2] = {headDim_, headDim_}; + const uint32_t srcShape[2] = {headDim_, 1}; + uint32_t eleCount = 64; + + AscendC::DataCopy(qTensor, queryGM_[offset], headDim_); + qInQue_.EnQue(qTensor); + + qTensor = qInQue_.DeQue(); + auto kvCacheFp32Tensor = kvCacheFp32Buf_.Get(); + auto qFp32Tensor = qFp32Buf_.Get(); + auto attentionFp32Tensor = attentionFp32Buf_.Get(); + AscendC::Cast(qFp32Tensor, qTensor, RoundMode::CAST_NONE, headDim_); + qInQue_.FreeTensor(qTensor); + AscendC::BroadCast(qBroadCastTensor, qFp32Tensor, dstShape, srcShape, broadCastTensor); + AscendC::Duplicate(attentionFp32Tensor, 0.0f, headDim_); + AscendC::PipeBarrier(); + + AscendC::MulAddDst(attentionFp32Tensor, qBroadCastTensor, kvCacheFp32Tensor, eleCount, headDim_, + {1, 1, 1, 0, 16, 16}); + AscendC::MulAddDst(attentionFp32Tensor[eleCount], qBroadCastTensor[eleCount], kvCacheFp32Tensor[eleCount], + eleCount, headDim_, {1, 1, 1, 0, 16, 16}); + AscendC::PipeBarrier(); + + auto attentionTensor = attentionOutQue_.AllocTensor(); + AscendC::Cast(attentionTensor, attentionFp32Tensor, RoundMode::CAST_ROUND, headDim_); + attentionOutQue_.EnQue(attentionTensor); + + attentionTensor = attentionOutQue_.DeQue(); + AscendC::DataCopy(attentionOutGM_[offset], attentionTensor, headDim_); + attentionOutQue_.FreeTensor(attentionTensor); + } +} + +template +__aicore__ inline void LightningAttentionDecode::UpdateKVCache(uint32_t kvCacheOffset, uint32_t offset, + uint32_t headIndex) +{ + uint64_t mask = 64; + auto kvCacheTensor = kvCacheBuf_.Get(); + auto kTensor = kBuf_.Get(); + auto vTensor = vBuf_.Get(); + auto kvFp32Tensor = kvFp32Buf_.Get(); + auto broadCastTensor = broadCastBuf_.Get(); + const uint32_t dstShape[2] = {headDim_, headDim_}; + const uint32_t srcShape[2] = {headDim_, 1}; + AscendC::DataCopy(kvCacheTensor, kvCacheHistoryGM_[kvCacheOffset], eleCountPerKVCache_); + 
// UpdateKVCache continues: load the current step's k/v rows, then an MTE2->V event guards the
// vector reads of the freshly-copied UB data. Both branches compute
// KV_cache = decay * KV_cache + broadcast(K) * V; the non-half branch stages everything in fp32.
// NOTE(review): the Mul calls use hard-coded mask=64 / repeat=128 / offset [64], which presumably
// assumes headDim_ == 128 — verify against the supported head-dim range before generalising.
AscendC::DataCopy(kTensor, keyGM_[offset], headDim_); + AscendC::DataCopy(vTensor, valueGM_[offset], headDim_); + + int32_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + + if constexpr (IsSameType::value) { + float decayLambda = decayBuf_.Get().GetValue(headIndex); + // multiply with decay + AscendC::Muls(kvCacheTensor, kvCacheTensor, decayLambda, eleCountPerKVCache_); + // KV_cur = Ki * Vi + AscendC::PipeBarrier(); + AscendC::BroadCast(kvFp32Tensor, kTensor, dstShape, srcShape, broadCastTensor); + AscendC::PipeBarrier(); + AscendC::Mul(kvFp32Tensor, kvFp32Tensor, vTensor, mask, 128, {1, 1, 1, 16, 16, 0}); + AscendC::Mul(kvFp32Tensor[64], kvFp32Tensor[64], vTensor[64], mask, 128, {1, 1, 1, 16, 16, 0}); + AscendC::PipeBarrier(); + // KV_cache = KV_cur + KV_cache * kv_decay + AscendC::Add(kvCacheTensor, kvCacheTensor, kvFp32Tensor, eleCountPerKVCache_); + } else { + float decayLambda = decayFp32Buf_.Get().GetValue(headIndex); + auto kvCacheFp32Tensor = kvCacheFp32Buf_.Get(); + auto kFp32Tensor = kFp32Buf_.Get(); + auto vFp32Tensor = vFp32Buf_.Get(); + // cast kvCache to fp32 and multiply with decay + AscendC::Cast(kvCacheFp32Tensor, kvCacheTensor, RoundMode::CAST_NONE, eleCountPerKVCache_); + AscendC::PipeBarrier(); + AscendC::Muls(kvCacheFp32Tensor, kvCacheFp32Tensor, decayLambda, eleCountPerKVCache_); + + // KV_cur = Ki * Vi + AscendC::Cast(kFp32Tensor, kTensor, RoundMode::CAST_NONE, headDim_); + AscendC::Cast(vFp32Tensor, vTensor, RoundMode::CAST_NONE, headDim_); + AscendC::PipeBarrier(); + AscendC::BroadCast(kvFp32Tensor, kFp32Tensor, dstShape, srcShape, broadCastTensor); + AscendC::PipeBarrier(); + AscendC::Mul(kvFp32Tensor, kvFp32Tensor, vFp32Tensor, mask, 128, {1, 1, 1, 16, 16, 0}); + AscendC::Mul(kvFp32Tensor[64], kvFp32Tensor[64], vFp32Tensor[64], mask, 128, {1, 1, 1, 16, 16, 0}); + AscendC::PipeBarrier(); + // KV_cache = KV_cur + KV_cache * kv_decay + 
// SaveKVCache: casts fp32 accumulation back to T (CAST_ROUND) when T != float, with a V->MTE3
// event before the GM store; WaitKVCacheSaved raises an MTE3->MTE2 event so the next iteration's
// history load cannot start before this store retires.
AscendC::Add(kvCacheFp32Tensor, kvCacheFp32Tensor, kvFp32Tensor, eleCountPerKVCache_); + } +} + +template +__aicore__ inline void LightningAttentionDecode::SaveKVCache(uint32_t kvCacheOffset) +{ + auto kvCacheTensor = kvCacheBuf_.Get(); + if constexpr (!IsSameType::value) { + auto kvCacheFp32Tensor = kvCacheFp32Buf_.Get(); + AscendC::Cast(kvCacheTensor, kvCacheFp32Tensor, RoundMode::CAST_ROUND, eleCountPerKVCache_); + int32_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + } + AscendC::DataCopy(kvCacheOutGM_[kvCacheOffset], kvCacheTensor, eleCountPerKVCache_); +} + +template +__aicore__ inline void LightningAttentionDecode::WaitKVCacheSaved() +{ + int32_t eventIdMte3ToMte2 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_MTE2)); + SetFlag(eventIdMte3ToMte2); + WaitFlag(eventIdMte3ToMte2); +} + +} // namespace LightningAttention + +#endif // LIGHTNING_ATTENTION_DECODE_H diff --git a/csrc/lightning_attention_prefill/lightning_attention_prefill_torch_adpt.h b/csrc/lightning_attention_prefill/lightning_attention_prefill_torch_adpt.h new file mode 100644 index 00000000000..1c627a1dd04 --- /dev/null +++ b/csrc/lightning_attention_prefill/lightning_attention_prefill_torch_adpt.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 * NOTE(review): adapter below — npu_lightning_attention_prefill allocates the BNSD attention
 * output with query's shape/dtype and a per-head [B, N, D, D] kv-cache output, then dispatches
 * aclnnLightningAttentionPrefill via EXEC_NPU_CMD. When actual_seq_len is absent it defaults every
 * batch entry to query.size(2) (the max sequence length). The stripped "<...>" template arguments
 * (std::tuple, std::vector, c10::optional) must be restored from the original patch. The
 * brace-initialised output_size_* presumably deduce std::initializer_list for at::empty —
 * confirm against the ATen overload set.
 + */ + +#ifndef LIGHTNING_ATTENTION_PREFILL_TORCH_ADPT_H +#define LIGHTNING_ATTENTION_PREFILL_TORCH_ADPT_H + +namespace vllm_ascend { + std::tuple npu_lightning_attention_prefill( + const at::Tensor &query, + const at::Tensor &key, + const at::Tensor &value, + const at::Tensor &slope_rate, + int64_t block_size, + const c10::optional &kv_history, + at::OptionalIntArrayRef actual_seq_len) + { + auto default_seq_len = std::vector(query.size(0), query.size(2)); + auto actual_seq_len_value = actual_seq_len.value_or(default_seq_len); + auto output_size_0 = {query.size(0), query.size(1), query.size(2), query.size(3)}; + auto output_size_1 = {query.size(0), query.size(1), query.size(3), query.size(3)}; + auto output_dtype_0 = query.scalar_type(); + at::Tensor attention_out = at::empty(output_size_0, query.options().dtype(output_dtype_0)); + at::Tensor kv_caches_out = at::empty(output_size_1, query.options().dtype(output_dtype_0)); + EXEC_NPU_CMD( + aclnnLightningAttentionPrefill, + query, + key, + value, + slope_rate, + kv_history, + block_size, + actual_seq_len_value, + "BNSD", + attention_out, + kv_caches_out); + return std::tuple(attention_out, kv_caches_out); + } +} + +#endif diff --git a/csrc/lightning_attention_prefill/op_host/CMakeLists.txt b/csrc/lightning_attention_prefill/op_host/CMakeLists.txt new file mode 100644 index 00000000000..49da515a733 --- /dev/null +++ b/csrc/lightning_attention_prefill/op_host/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
# NOTE(review): op_host CMake wiring — registers the op's def/tiling/proto/aclnn sources with the
# CANN build targets and installs the public aclnn header. Per the diff marker at the end of this
# hunk, the file is missing a trailing newline at EOF — add one in the original patch.
+# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME LightningAttentionPrefill + OPTIONS --cce-auto-sync + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnnInner PRIVATE + lightning_attention_prefill_def.cpp +) + +target_sources(opapi PRIVATE + aclnn_lightning_attention.cpp +) + +target_sources(optiling PRIVATE + lightning_attention_prefill_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + lightning_attention_prefill_tiling.cpp + ) +endif () + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + lightning_attention_prefill_proto.cpp +) + +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_lightning_attention_prefill.h + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) \ No newline at end of file diff --git a/csrc/lightning_attention_prefill/op_host/aclnn_lightning_attention.cpp b/csrc/lightning_attention_prefill/op_host/aclnn_lightning_attention.cpp new file mode 100644 index 00000000000..7dc4e642013 --- /dev/null +++ b/csrc/lightning_attention_prefill/op_host/aclnn_lightning_attention.cpp @@ -0,0 +1,72 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
 * NOTE(review): the two exported aclnn entry points below are thin pass-throughs to the
 * generated aclnnInner* implementations (declared extern here, defined by the op_host codegen);
 * they add no validation or state of their own. This is the standard two-phase aclnn contract:
 * GetWorkspaceSize computes the workspace size and builds an executor, the second call launches
 * on the given stream.
 + */ + +#include "graph/types.h" +#include "aclnn_lightning_attention_prefill.h" + + +#ifdef __cplusplus +extern "C" { +#endif + +extern aclnnStatus aclnnInnerLightningAttentionPrefillGetWorkspaceSize( + const aclTensor *query, + const aclTensor *key, + const aclTensor *value, + const aclTensor *slopeRate, + const aclTensor *kvHistoryOptional, + int64_t blockSize, + const aclIntArray *actualSeqLen, + char *inputLayoutOptional, + const aclTensor *attentionOut, + const aclTensor *kvCachesOut, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +extern aclnnStatus aclnnInnerLightningAttentionPrefill( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + + +aclnnStatus aclnnLightningAttentionPrefillGetWorkspaceSize( + const aclTensor *query, + const aclTensor *key, + const aclTensor *value, + const aclTensor *slopeRate, + const aclTensor *kvHistoryOptional, + int64_t blockSize, + const aclIntArray *actualSeqLen, + char *inputLayoutOptional, + const aclTensor *attentionOut, + const aclTensor *kvCachesOut, + uint64_t *workspaceSize, + aclOpExecutor **executor) +{ + aclnnStatus ret = aclnnInnerLightningAttentionPrefillGetWorkspaceSize( + query, key, value, slopeRate, kvHistoryOptional, blockSize, actualSeqLen, + inputLayoutOptional, attentionOut, kvCachesOut, workspaceSize, executor); + return ret; +} + +aclnnStatus aclnnLightningAttentionPrefill( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream) +{ + aclnnStatus ret = aclnnInnerLightningAttentionPrefill(workspace, workspaceSize, executor, stream); + return ret; +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/lightning_attention_prefill/op_host/aclnn_lightning_attention_prefill.h b/csrc/lightning_attention_prefill/op_host/aclnn_lightning_attention_prefill.h new file mode 100644 index 00000000000..6da0081ca0a --- /dev/null +++ b/csrc/lightning_attention_prefill/op_host/aclnn_lightning_attention_prefill.h @@ -0,0 +1,68 @@ +/** 
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ACLNN_LIGHTNING_ATTENTION_PREFILL_H_ +#define ACLNN_LIGHTNING_ATTENTION_PREFILL_H_ + +#include "aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* function: aclnnLightningAttentionPrefillGetWorkspaceSize + * parameters : + * query : required + * key : required + * value : required + * slopeRate : required + * kvHistoryOptional : optional + * blockSize : required + * actualSeqLen : required + * inputLayoutOptional : optional + * attentionOut : required + * kvCachesOut : required + * workspaceSize : size of workspace(output). + * executor : executor context(output). + */ +__attribute__((visibility("default"))) +aclnnStatus aclnnLightningAttentionPrefillGetWorkspaceSize( + const aclTensor *query, + const aclTensor *key, + const aclTensor *value, + const aclTensor *slopeRate, + const aclTensor *kvHistoryOptional, + int64_t blockSize, + const aclIntArray *actualSeqLen, + char *inputLayoutOptional, + const aclTensor *attentionOut, + const aclTensor *kvCachesOut, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +/* function: aclnnLightningAttentionPrefill + * parameters : + * workspace : workspace memory addr(input). + * workspaceSize : size of workspace(input). + * executor : executor context(input). + * stream : acl stream. 
 + */ +__attribute__((visibility("default"))) +aclnnStatus aclnnLightningAttentionPrefill( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_def.cpp b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_def.cpp new file mode 100644 index 00000000000..4611903ac36 --- /dev/null +++ b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_def.cpp @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
 * NOTE(review): OpDef registration — declares the prefill op's five inputs (kv_history optional),
 * two required outputs, and attrs block_size / actual_seq_len / input_layout (default "BNSD"),
 * each supporting fp16/bf16/fp32 in ND format, for both ascend910b and ascend910_93 SoCs.
 * The per-input DataType/Format lists all have exactly three entries matching the three dtypes.
 + */ + +#include "register/op_def_registry.h" + + +namespace ops { +class LightningAttentionPrefill : public OpDef { +public: + explicit LightningAttentionPrefill(const char* name) : OpDef(name) + { + this->Input("query") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("key") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("value") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("slope_rate") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("kv_history") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("attention") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("kv_caches") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("block_size").Int(); + this->Attr("actual_seq_len").ListInt({}); + this->Attr("input_layout").AttrType(OPTIONAL).String("BNSD"); + + 
this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910_93"); + } +}; + +OP_ADD(LightningAttentionPrefill); +} diff --git a/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_proto.cpp b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_proto.cpp new file mode 100644 index 00000000000..4058dcea2a2 --- /dev/null +++ b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_proto.cpp @@ -0,0 +1,55 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
 * NOTE(review): proto — InferShape copies query's shape to attention_out and builds the kv_caches
 * shape as [q0, q1, q3, q3] (i.e. [B, N, D, D] under BNSD); InferDataType propagates query's
 * dtype to both outputs. Neither callback checks that q_shape has 4 dims before indexing
 * DIM_3 — presumably guaranteed upstream; confirm or add a rank check.
 + */ + +#include "register/op_def_registry.h" + +using namespace ge; + + +namespace ops { +static constexpr size_t INDEX_IN_Q = 0; +static constexpr size_t INDEX_IN_K = 1; +static constexpr size_t INDEX_IN_V = 2; +static constexpr size_t INDEX_IN_SLP_RATE = 3; +static constexpr size_t INDEX_IN_KV_HIS = 4; +static constexpr size_t DIM_2 = 2; +static constexpr size_t DIM_3 = 3; +static constexpr size_t INDEX_OUT_ATTN = 0; +static constexpr size_t INDEX_OUT_KV_CACHES = 1; + +static ge::graphStatus InferShapeLightningAttentionPrefill(gert::InferShapeContext* context) +{ + const gert::Shape* q_shape = context->GetInputShape(INDEX_IN_Q); + gert::Shape* attn_out_shape = context->GetOutputShape(INDEX_OUT_ATTN); + gert::Shape* kv_caches_shape = context->GetOutputShape(INDEX_OUT_KV_CACHES); + *attn_out_shape = *q_shape; + + kv_caches_shape->SetDimNum(q_shape->GetDimNum()); + kv_caches_shape->SetDim(0, q_shape->GetDim(0)); + kv_caches_shape->SetDim(1, q_shape->GetDim(1)); + kv_caches_shape->SetDim(DIM_2, q_shape->GetDim(DIM_3)); + kv_caches_shape->SetDim(DIM_3, q_shape->GetDim(DIM_3)); + + return GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeLightningAttentionPrefill(gert::InferDataTypeContext *context) +{ + const auto inputDataType = context->GetInputDataType(INDEX_IN_Q); + context->SetOutputDataType(INDEX_OUT_ATTN, inputDataType); + context->SetOutputDataType(INDEX_OUT_KV_CACHES, inputDataType); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(LightningAttentionPrefill) + .InferShape(InferShapeLightningAttentionPrefill) + .InferDataType(InferDataTypeLightningAttentionPrefill); + +} diff --git a/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_tiling.cpp b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_tiling.cpp new file mode 100644 index 00000000000..2d56732ee9a --- /dev/null +++ b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_tiling.cpp @@ -0,0 +1,345 @@ +/** + * Copyright (c) 2025 Huawei 
Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "lightning_attention_prefill_tiling.h" +#include "register/op_impl_registry.h" + +namespace optiling +{ + +static constexpr uint32_t MAX_BASE_M = 128; +static constexpr uint32_t MAX_BATCH_SIZE = 256; +static constexpr uint32_t MAX_AIV_NUM = 50; +static constexpr uint32_t ATTR_BLOCK_SIZE = 0; +static constexpr uint32_t ATTR_ACTUAL_SEQ_LEN_ARRAY = 1; +static constexpr uint32_t HALF_BYTE_SIZE = 2; +static constexpr uint32_t FLOAT_BYTE_SIZE = 4; +static constexpr size_t DIM_2 = 2; +static constexpr size_t DIM_3 = 3; + +bool LightningAttentionPrefillTiling::IsCapable() +{ + return true; +} + +ge::graphStatus LightningAttentionPrefillTiling::GetPlatformInfo() +{ + aicNum_ = ascendcPlatform_->GetCoreNumAic(); + aivNum_ = ascendcPlatform_->GetCoreNumAiv(); + actualUsedAivNum_ = aivNum_; + ascendcPlatform_->GetCoreMemSize(platform_ascendc::CoreMemType::UB, aicoreParams_.ubSize); + ascendcPlatform_->GetCoreMemSize(platform_ascendc::CoreMemType::L1, aicoreParams_.l1Size); + ascendcPlatform_->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, aicoreParams_.l0cSize); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LightningAttentionPrefillTiling::GetShapeAttrsInfo() +{ + auto attrs = context_->GetAttrs(); + auto *blockSize = attrs->GetAttrPointer(ATTR_BLOCK_SIZE); + blockSize_ = *blockSize; + tilingData_.laBaseParams.set_blockSize(blockSize_); + + auto *seqLenArray = 
// NOTE(review): GetShapeAttrsInfo fails cleanly when seqLenArray is null or its size mismatches
// batchSize, but it never checks batchSize <= MAX_BATCH_SIZE (256): the loop indexes the two
// fixed-size vectors (sized MAX_BATCH_SIZE) by seq-len index, so a larger batch would write out
// of bounds — add a guard. blockSize_ is also used as a divisor/modulus here with no check that
// the attr is non-zero. The stripped "<...>" args (GetListInt pointer type, std::vector element
// types) must be restored from the original patch.
attrs->GetListInt(ATTR_ACTUAL_SEQ_LEN_ARRAY); + auto qShape = context_->GetInputShape(0)->GetStorageShape(); + uint32_t batchSize = qShape.GetDim(0); + uint32_t headNum = qShape.GetDim(1); + uint32_t maxSeqLen = qShape.GetDim(DIM_2); + std::vector blockCountPerBatch(MAX_BATCH_SIZE, maxSeqLen / blockSize_); + std::vector tailBlockSize(MAX_BATCH_SIZE); + if (!seqLenArray || seqLenArray->GetSize() != batchSize) { + return ge::GRAPH_FAILED; + } + for (uint32_t index = 0; index < seqLenArray->GetSize(); ++index) { + tailBlockSize[index] = seqLenArray->GetData()[index] % blockSize_; + blockCountPerBatch[index] = (seqLenArray->GetData()[index] + blockSize_ - 1) / blockSize_; + totalBlockCount_ += blockCountPerBatch[index]; + } + totalBlockCount_ *= headNum; + tilingData_.laBaseParams.set_tailBlockSize(tailBlockSize.data()); + tilingData_.laBaseParams.set_blockCountPerBatch(blockCountPerBatch.data()); + + if (!AnalyzeDType()) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LightningAttentionPrefillTiling::DoOpTiling() +{ + auto qShape = context_->GetInputShape(0)->GetStorageShape(); + // set base params + tilingData_.laBaseParams.set_batchSize(qShape.GetDim(0)); + tilingData_.laBaseParams.set_headNum(qShape.GetDim(1)); + tilingData_.laBaseParams.set_maxSeqLen(qShape.GetDim(DIM_2)); + tilingData_.laBaseParams.set_headDim(qShape.GetDim(DIM_3)); + tilingData_.laBaseParams.set_eleCountPerHead(qShape.GetDim(DIM_2) * qShape.GetDim(DIM_3)); + tilingData_.laBaseParams.set_eleCountPerBlock(blockSize_ * qShape.GetDim(DIM_3)); + + qSBlockSize_ = blockSize_; + kvSBlockSize_ = blockSize_; + + headDimBlock_ = tilingData_.laBaseParams.get_headDim(); + + taskNum_ = tilingData_.laBaseParams.get_batchSize() * tilingData_.laBaseParams.get_headNum(); + if (taskNum_ < actualUsedAivNum_) { + actualUsedAivNum_ = taskNum_; + } + tilingData_.laBaseParams.set_actualUsedAivNum(actualUsedAivNum_); + + SetHeadStartEnd(); + + return ge::GRAPH_SUCCESS; +} + 
// GetWorkspaceSize: per-core scratch for P (block_size x block_size), O_intra and the updated-key
// buffer (both block_size x head_dim, fp32), plus shared per-head decay tables. NOTE(review):
// blockDecayWorkspaceSize = 8 is a magic constant — presumably one 8-byte/aligned slot per head;
// name it or comment it in the original patch. GetTilingKey always returns 0 (single kernel
// variant). PostTiling serialises tilingData_ into the raw tiling buffer and adds the platform's
// libapi workspace on top of workspaceSize_.
+ge::graphStatus LightningAttentionPrefillTiling::DoLibApiTiling() +{ + if (!SetMatmulTiling()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +uint64_t LightningAttentionPrefillTiling::GetTilingKey() const +{ + return 0; +} + +ge::graphStatus LightningAttentionPrefillTiling::GetWorkspaceSize() +{ + uint32_t headNum = tilingData_.laBaseParams.get_headNum(); + // workspace reserved for each core + // - p + // - oIntra + // - updatedKey + + auto dataSize = mm1InDType_ == matmul_tiling::DataType::DT_FLOAT ? FLOAT_BYTE_SIZE : HALF_BYTE_SIZE; + // workspace to store P, which is type float16/bfloat16/float32 with shape BLOCK_SIZE * BLOCK_SIZE + uint32_t pWorkspaceSize = dataSize * blockSize_ * blockSize_; + // workspace to store Ointra, which is type float with shape BLOCK_SIZE * HEAD_DIM + uint32_t oIntraWorkspaceSize = calcTypeSize_ * tilingData_.laBaseParams.get_eleCountPerBlock(); + // workspace to store O_inter/updated Ki, which is type float16/bfloat16/float32 with shape BLOCK_SIZE * HEAD_DIM + uint32_t updatedKeyWorkspaceSize = calcTypeSize_ * tilingData_.laBaseParams.get_eleCountPerBlock(); + workspaceSize_ += (pWorkspaceSize + oIntraWorkspaceSize + updatedKeyWorkspaceSize) * + actualUsedAivNum_; + + // workSpace shared by every core + // - diagDecay, type float with shape (HEAD, BlockSize, BlockSize) + uint32_t diagDecayWorkspaceSize = headNum * blockSize_ * blockSize_ * calcTypeSize_; + workspaceSize_ += diagDecayWorkspaceSize; + // - (qDecay + kDecay + blockDecay) * HEAD + uint32_t qDecayWorkspaceSize = blockSize_; + uint32_t kDecayWorkspaceSize = blockSize_; + uint32_t blockDecayWorkspaceSize = 8; + workspaceSize_ += headNum * (qDecayWorkspaceSize + kDecayWorkspaceSize + blockDecayWorkspaceSize); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LightningAttentionPrefillTiling::PostTiling() +{ + auto blockDim = CalcTschBlockDim(actualUsedAivNum_, aicNum_, aivNum_); + context_->SetBlockDim(blockDim); + size_t *currentWorkspace = 
context_->GetWorkspaceSizes(1); + currentWorkspace[0] = workspaceSize_ + ascendcPlatform_->GetLibApiWorkSpaceSize(); + tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), + context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize()); + return ge::GRAPH_SUCCESS; +} + + +bool LightningAttentionPrefillTiling::AnalyzeDType() +{ + inputDType_ = context_->GetInputDesc(0)->GetDataType(); + switch (inputDType_) { + case ge::DT_FLOAT16: + mm1InDType_ = matmul_tiling::DataType::DT_FLOAT16; + mm1OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm2InDType_ = matmul_tiling::DataType::DT_FLOAT16; + mm2OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm3InDType_ = matmul_tiling::DataType::DT_FLOAT16; + mm3OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm4InDType_ = matmul_tiling::DataType::DT_FLOAT16; + mm4OutDType_ = matmul_tiling::DataType::DT_FLOAT; + calcTypeSize_ = ge::GetSizeByDataType(ge::DT_FLOAT); + break; + case ge::DT_BF16: + mm1InDType_ = matmul_tiling::DataType::DT_BF16; + mm1OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm2InDType_ = matmul_tiling::DataType::DT_BF16; + mm2OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm3InDType_ = matmul_tiling::DataType::DT_BF16; + mm3OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm4InDType_ = matmul_tiling::DataType::DT_BF16; + mm4OutDType_ = matmul_tiling::DataType::DT_FLOAT; + calcTypeSize_ = ge::GetSizeByDataType(ge::DT_FLOAT); + break; + case ge::DT_FLOAT: + mm1InDType_ = matmul_tiling::DataType::DT_FLOAT; + mm1OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm2InDType_ = matmul_tiling::DataType::DT_FLOAT; + mm2OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm3InDType_ = matmul_tiling::DataType::DT_FLOAT; + mm3OutDType_ = matmul_tiling::DataType::DT_FLOAT; + mm4InDType_ = matmul_tiling::DataType::DT_FLOAT; + mm4OutDType_ = matmul_tiling::DataType::DT_FLOAT; + calcTypeSize_ = ge::GetSizeByDataType(ge::DT_FLOAT); + break; + default: + return 
// AnalyzeDType tail: all four matmuls accumulate in fp32 regardless of input dtype; unknown
// dtypes are rejected here, which fails tiling via GetShapeAttrsInfo.
// SetHeadStartEnd: greedy load-balancer — walks the batch*head task list once, assigning to each
// of the <= MAX_AIV_NUM (50) cores roughly totalBlockCount/(remaining cores) blocks, recording the
// inclusive [headStart, headEnd] task range per core as uint16_t.
// The four SetMatmulTilingFor* helpers configure mm1 (Q x K^T), mm2 (P x V), mm3 (Q x KV) and
// mm4 (K^T x V), all biasless with fp32 C-matrix in VECCALC and base blocks capped at
// MAX_BASE_M = 128 via SetFixSplit; stepM/stepN are forced to 1 after GetTiling.
// NOTE(review): TilingPrepareForLightningAttentionPrefill is an intentional no-op (nothing to
// parse), and per the diff marker the tiling.cpp file lacks a trailing newline at EOF.
false; + } + return true; +} + +void LightningAttentionPrefillTiling::SetHeadStartEnd() +{ + uint32_t headStartIdx = 0; + uint32_t headEndIdx = 0; + uint32_t totalBlockCount = totalBlockCount_; + uint32_t blockCountEachCore; + std::vector headStart(MAX_AIV_NUM, 0); + std::vector headEnd(MAX_AIV_NUM, 0); + for (uint32_t coreId = 0, currBlockCount, batchId; coreId < actualUsedAivNum_; + ++coreId, headStartIdx = ++headEndIdx) { + blockCountEachCore = totalBlockCount / (actualUsedAivNum_ - coreId); + for (currBlockCount = 0u; taskNum_ - headEndIdx > actualUsedAivNum_ - coreId;) { + batchId = headEndIdx / tilingData_.laBaseParams.get_headNum(); + currBlockCount += tilingData_.laBaseParams.get_blockCountPerBatch()[batchId]; + if (currBlockCount >= blockCountEachCore) { + break; + } else { + ++headEndIdx; + } + } + totalBlockCount -= currBlockCount; + headStart[coreId] = (uint16_t)headStartIdx; + headEnd[coreId] = (uint16_t)headEndIdx; + } + tilingData_.laBaseParams.set_headStart(headStart.data()); + tilingData_.laBaseParams.set_headEnd(headEnd.data()); +} + +bool LightningAttentionPrefillTiling::SetMatmulTiling() +{ + return SetMatmulTilingForQXK() && SetMatmulTilingForPXV() && + SetMatmulTilingForQXKV() && SetMatmulTilingForKXV(); +} + +bool LightningAttentionPrefillTiling::SetMatmulTilingForQXK() +{ + matmul_tiling::MatmulApiTiling mm1(*ascendcPlatform_); + mm1.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm1InDType_, false); + mm1.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm1InDType_, true); + mm1.SetCType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, mm1OutDType_); + mm1.SetShape(qSBlockSize_, kvSBlockSize_, headDimBlock_); + mm1.SetOrgShape(qSBlockSize_, kvSBlockSize_, headDimBlock_, headDimBlock_); + mm1.SetBias(false); + if (mm1.SetBufferSpace(aicoreParams_.l1Size, aicoreParams_.l0cSize) != 0) { + return false; + } + if (mm1.SetFixSplit(std::min(qSBlockSize_, MAX_BASE_M), 
std::min(kvSBlockSize_, MAX_BASE_M)) != 0) { + return false; + } + if (mm1.GetTiling(tilingData_.mm1TilingData) != 0) { + return false; + } + tilingData_.mm1TilingData.set_stepM(1); + tilingData_.mm1TilingData.set_stepN(1); + return true; +} + +bool LightningAttentionPrefillTiling::SetMatmulTilingForPXV() +{ + matmul_tiling::MatmulApiTiling mm2(*ascendcPlatform_); + mm2.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm2InDType_, false); + mm2.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm2InDType_, true); + mm2.SetCType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, mm2OutDType_); + mm2.SetShape(qSBlockSize_, headDimBlock_, kvSBlockSize_); + mm2.SetOrgShape(qSBlockSize_, headDimBlock_, kvSBlockSize_, kvSBlockSize_); + mm2.SetBias(false); + if (mm2.SetBufferSpace(aicoreParams_.l1Size, aicoreParams_.l0cSize) != 0) { + return false; + } + if (mm2.SetFixSplit(std::min(qSBlockSize_, MAX_BASE_M), std::min(headDimBlock_, MAX_BASE_M)) != 0) { + return false; + } + if (mm2.GetTiling(tilingData_.mm2TilingData) != 0) { + return false; + } + tilingData_.mm2TilingData.set_stepM(1); + tilingData_.mm2TilingData.set_stepN(1); + return true; +} + +bool LightningAttentionPrefillTiling::SetMatmulTilingForQXKV() +{ + matmul_tiling::MatmulApiTiling mm3(*ascendcPlatform_); + mm3.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm3InDType_, false); + mm3.SetBType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, mm3InDType_, false); + mm3.SetCType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, mm3OutDType_); + mm3.SetShape(qSBlockSize_, headDimBlock_, headDimBlock_); + mm3.SetOrgShape(qSBlockSize_, headDimBlock_, headDimBlock_, headDimBlock_); + mm3.SetBias(false); + if (mm3.SetBufferSpace(aicoreParams_.l1Size, aicoreParams_.l0cSize) != 0) { + return false; + } + if (mm3.SetFixSplit(std::min(qSBlockSize_, MAX_BASE_M), std::min(headDimBlock_, MAX_BASE_M)) != 0) { + 
return false; + } + if (mm3.GetTiling(tilingData_.mm3TilingData) != 0) { + return false; + } + tilingData_.mm3TilingData.set_stepM(1); + tilingData_.mm3TilingData.set_stepN(1); + return true; +} + +bool LightningAttentionPrefillTiling::SetMatmulTilingForKXV() +{ + matmul_tiling::MatmulApiTiling mm4(*ascendcPlatform_); + mm4.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm4InDType_, true); + mm4.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, mm4InDType_, false); + mm4.SetCType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, mm4OutDType_); + mm4.SetShape(headDimBlock_, headDimBlock_, kvSBlockSize_); + mm4.SetOrgShape(headDimBlock_, headDimBlock_, kvSBlockSize_, kvSBlockSize_); + mm4.SetBias(false); + if (mm4.SetBufferSpace(aicoreParams_.l1Size, aicoreParams_.l0cSize) != 0) { + return false; + } + if (mm4.SetFixSplit(std::min(headDimBlock_, MAX_BASE_M), std::min(headDimBlock_, MAX_BASE_M)) != 0) { + return false; + } + if (mm4.GetTiling(tilingData_.mm4TilingData) != 0) { + return false; + } + tilingData_.mm4TilingData.set_stepM(1); + tilingData_.mm4TilingData.set_stepN(1); + return true; +} + + +ASCENDC_EXTERN_C ge::graphStatus TilingLightningAttentionPrefill(gert::TilingContext* context) +{ + LightningAttentionPrefillTiling tiling(context); + return tiling.DoTiling(); +} + +ASCENDC_EXTERN_C ge::graphStatus TilingPrepareForLightningAttentionPrefill(gert::TilingParseContext *context) +{ + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(LightningAttentionPrefill) + .Tiling(TilingLightningAttentionPrefill) + .TilingParse(TilingPrepareForLightningAttentionPrefill); + +} \ No newline at end of file diff --git a/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_tiling.h b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_tiling.h new file mode 100644 index 00000000000..0a54a9e3e0d --- /dev/null +++ 
b/csrc/lightning_attention_prefill/op_host/lightning_attention_prefill_tiling.h @@ -0,0 +1,107 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LIGHTNING_ATTENTION_PREFILL_TILING_H +#define LIGHTNING_ATTENTION_PREFILL_TILING_H + +#include + +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" +#include "tiling/tiling_base.h" + +namespace optiling { + +BEGIN_TILING_DATA_DEF(LightningAttentionPrefillBaseParams) + TILING_DATA_FIELD_DEF(uint32_t, batchSize); + TILING_DATA_FIELD_DEF(uint32_t, headNum); + TILING_DATA_FIELD_DEF(uint32_t, maxSeqLen); + TILING_DATA_FIELD_DEF(uint32_t, headDim); + TILING_DATA_FIELD_DEF(uint32_t, blockSize); + TILING_DATA_FIELD_DEF(uint32_t, actualUsedAivNum); + TILING_DATA_FIELD_DEF(uint32_t, eleCountPerHead); + TILING_DATA_FIELD_DEF(uint32_t, eleCountPerBlock); + TILING_DATA_FIELD_DEF_ARR(uint16_t, 256, blockCountPerBatch); // max batch size 256 + TILING_DATA_FIELD_DEF_ARR(uint16_t, 256, tailBlockSize); // max batch size 256 + TILING_DATA_FIELD_DEF_ARR(uint16_t, 50, headStart); // max aiv num: 50 + TILING_DATA_FIELD_DEF_ARR(uint16_t, 50, headEnd); +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(LightningAttentionPrefillBaseParamsOp, LightningAttentionPrefillBaseParams) + +BEGIN_TILING_DATA_DEF(LightningAttentionPrefillTilingData) + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm1TilingData); + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm2TilingData); + 
TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm3TilingData); + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm4TilingData); + TILING_DATA_FIELD_DEF_STRUCT(LightningAttentionPrefillBaseParams, laBaseParams); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(LightningAttentionPrefill, LightningAttentionPrefillTilingData) + +struct LightningAttentionPrefillCompileInfo {}; + +class LightningAttentionPrefillTiling : public TilingBaseClass { +public: + explicit LightningAttentionPrefillTiling(gert::TilingContext *context) + : TilingBaseClass(context) + { + ascendcPlatform_.reset(new platform_ascendc::PlatformAscendC(context->GetPlatformInfo())); + } +protected: + bool IsCapable() override; + + ge::graphStatus GetPlatformInfo() override; + + ge::graphStatus GetShapeAttrsInfo() override; + + ge::graphStatus DoOpTiling() override; + + ge::graphStatus DoLibApiTiling() override; + + uint64_t GetTilingKey() const override; + + ge::graphStatus GetWorkspaceSize() override; + + ge::graphStatus PostTiling() override; +private: + bool AnalyzeDType(); + void SetHeadStartEnd(); + bool SetMatmulTiling(); + bool SetMatmulTilingForQXK(); + bool SetMatmulTilingForPXV(); + bool SetMatmulTilingForQXKV(); + bool SetMatmulTilingForKXV(); +private: + LightningAttentionPrefillTilingData tilingData_; + ge::DataType inputDType_; + uint32_t blockSize_; + uint32_t calcTypeSize_; + uint32_t aicNum_; + uint32_t aivNum_; + uint32_t actualUsedAivNum_; + uint32_t taskNum_; + uint32_t totalBlockCount_ = 0; + matmul_tiling::DataType mm1InDType_ = matmul_tiling::DataType::DT_FLOAT; + matmul_tiling::DataType mm1OutDType_ = matmul_tiling::DataType::DT_FLOAT; + matmul_tiling::DataType mm2InDType_ = matmul_tiling::DataType::DT_FLOAT; + matmul_tiling::DataType mm2OutDType_ = matmul_tiling::DataType::DT_FLOAT; + matmul_tiling::DataType mm3InDType_ = matmul_tiling::DataType::DT_FLOAT; + matmul_tiling::DataType mm3OutDType_ = matmul_tiling::DataType::DT_FLOAT; + matmul_tiling::DataType mm4InDType_ = 
matmul_tiling::DataType::DT_FLOAT; + matmul_tiling::DataType mm4OutDType_ = matmul_tiling::DataType::DT_FLOAT; + + uint32_t qSBlockSize_; + uint32_t kvSBlockSize_; + uint32_t headDimBlock_; +}; + +} + +#endif \ No newline at end of file diff --git a/csrc/lightning_attention_prefill/op_kernel/lightning_attention_prefill.cpp b/csrc/lightning_attention_prefill/op_kernel/lightning_attention_prefill.cpp new file mode 100644 index 00000000000..ec3241201cc --- /dev/null +++ b/csrc/lightning_attention_prefill/op_kernel/lightning_attention_prefill.cpp @@ -0,0 +1,48 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "kernel_operator.h" +#include "lightning_attention_prefill.h" + +using namespace LightningAttention; + +#define COPY_TILING_DATA(tiling) \ + GET_TILING_DATA_WITH_STRUCT(LightningAttentionPrefillTilingData, tilingDataIn, tiling); \ + const LightningAttentionPrefillTilingData *__restrict tilingData = &tilingDataIn; \ + const TCubeTiling *__restrict mm1tiling = &(tilingData->mm1TilingData); \ + const TCubeTiling *__restrict mm2tiling = &(tilingData->mm2TilingData); \ + const TCubeTiling *__restrict mm3tiling = &(tilingData->mm3TilingData); \ + const TCubeTiling *__restrict mm4tiling = &(tilingData->mm4TilingData) + +extern "C" __global__ __aicore__ void lightning_attention_prefill( + GM_ADDR query, GM_ADDR key, GM_ADDR value, GM_ADDR slope_rate, GM_ADDR kv_history, GM_ADDR attention_out, + GM_ADDR kv_caches, GM_ADDR workspace, GM_ADDR tiling) { + AscendC::TPipe pipe; + COPY_TILING_DATA(tiling); +#if (ORIG_DTYPE_QUERY == DT_FLOAT16) + LightningAttentionPrefill op; + REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), op.mm1, mm1tiling, op.mm2, mm2tiling, op.mm3, mm3tiling, + op.mm4, mm4tiling); + op.Init(query, key, value, slope_rate, kv_history, attention_out, kv_caches, workspace, tilingData, &pipe); + op.Process(); +#elif (ORIG_DTYPE_QUERY == DT_BF16) + LightningAttentionPrefill op; + REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), op.mm1, mm1tiling, op.mm2, mm2tiling, op.mm3, mm3tiling, + op.mm4, mm4tiling); + op.Init(query, key, value, slope_rate, kv_history, attention_out, kv_caches, workspace, tilingData, &pipe); + op.Process(); +#elif (ORIG_DTYPE_QUERY == DT_FLOAT) + LightningAttentionPrefill op; + REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), op.mm1, mm1tiling, op.mm2, mm2tiling, op.mm3, mm3tiling, + op.mm4, mm4tiling); + op.Init(query, key, value, slope_rate, kv_history, attention_out, kv_caches, workspace, tilingData, &pipe); + op.Process(); +#endif +} \ No newline at end of file diff --git 
a/csrc/lightning_attention_prefill/op_kernel/lightning_attention_prefill.h b/csrc/lightning_attention_prefill/op_kernel/lightning_attention_prefill.h new file mode 100644 index 00000000000..2ed7f8f09d8 --- /dev/null +++ b/csrc/lightning_attention_prefill/op_kernel/lightning_attention_prefill.h @@ -0,0 +1,681 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LIGHTNING_ATTENTION_PREFILL_H +#define LIGHTNING_ATTENTION_PREFILL_H + +#include "kernel_operator.h" +#include "lib/matmul_intf.h" + +using namespace matmul; + +namespace LightningAttention { + +template +class LightningAttentionPrefill { +public: + __aicore__ inline LightningAttentionPrefill() + { + } + __aicore__ inline void Init(GM_ADDR query, GM_ADDR key, GM_ADDR value, GM_ADDR slope_rate, GM_ADDR kv_history, + GM_ADDR attention_out, GM_ADDR kv_caches, GM_ADDR workspace, + const LightningAttentionPrefillTilingData *__restrict tiling, AscendC::TPipe *pipe); + __aicore__ inline void Process(); + +public: + // define matmul object for matmul(Q, Kt) + using a1Type = MatmulType; + using b1Type = MatmulType; + using c1Type = MatmulType; + using bias1Type = MatmulType; + Matmul mm1; + + // define matmul object for matmul(P, V) + using a2Type = MatmulType; + using b2Type = MatmulType; + using c2Type = MatmulType; + using bias2Type = MatmulType; + Matmul mm2; + + // define matmul object for matmul(Q, KV) + using a3Type = MatmulType; + using b3Type = 
MatmulType; + using c3Type = MatmulType; + using bias3Type = MatmulType; + Matmul mm3; + + // define matmul object for matmul(Kt, V) + using a4Type = MatmulType; + using b4Type = MatmulType; + using c4Type = MatmulType; + using bias4Type = MatmulType; + Matmul mm4; + +private: + __aicore__ inline void InitWorkspace(GM_ADDR workspace); + __aicore__ inline void InitMask(); + __aicore__ inline void GenerateMask(uint32_t headIdx, float s); + __aicore__ inline void GenerateDecay(uint32_t headIdx, float s); + __aicore__ inline void GenerateTailBlockDecay(const AscendC::LocalTensor &decayTensor, uint32_t batchId, + float s); + __aicore__ inline float GetSlope(uint32_t headIdx); + __aicore__ inline void ComputeEachBlock(uint32_t offset, uint32_t headIdx, + const AscendC::LocalTensor &decayTensor); + __aicore__ inline void ComputeOIntra(uint32_t offset, uint32_t headIdx); + __aicore__ inline void CopyMIn(uint32_t headIdx, uint32_t maskOffset, uint32_t copyRows); + __aicore__ inline void CopyDecayIn(uint32_t headIdx); + __aicore__ inline void ComputePSplit(uint32_t headIdx, uint32_t computeRound, + const AscendC::LocalTensor &pOutTensor); + __aicore__ inline void CopyPOut(uint32_t computeRound); + __aicore__ inline void ComputeOInter(uint32_t offset, const AscendC::LocalTensor &qDecayTensor); + __aicore__ inline void InitKVCache(uint32_t kvCacheOffset); + __aicore__ inline void UpdateKVCache(uint32_t offset, const AscendC::LocalTensor &kDecayTensor, + const AscendC::LocalTensor &blockDecayTensor); + __aicore__ inline void SaveKVCache(uint32_t kvCacheOffset); + __aicore__ inline void WaitKVCacheSaved(); + __aicore__ inline void CopyOIntraIn(uint32_t attentionOffset); + __aicore__ inline void CalculateOFinal(const AscendC::LocalTensor &oInterTensor, uint32_t attentionOffset); + __aicore__ inline void CopyAttentionOut(uint32_t attentionOffset); + +private: + AscendC::GlobalTensor queryGM_; + AscendC::GlobalTensor keyGM_; + AscendC::GlobalTensor valueGM_; + AscendC::GlobalTensor 
slopeRateGM_; + AscendC::GlobalTensor attentionOutGM_; + AscendC::GlobalTensor outPWorkspaceGM_; + AscendC::GlobalTensor outIntraWorkspaceGM_; + AscendC::GlobalTensor maskGM_; + AscendC::GlobalTensor decayGM_; + AscendC::GlobalTensor updatedKeyGM_; + AscendC::GlobalTensor oInterWorkspaceGM_; + AscendC::GlobalTensor kvCacheHistoryGM_; + AscendC::GlobalTensor kvCacheOutGM_; + + const LightningAttentionPrefillTilingData *__restrict tiling_; + uint32_t headNum_; + uint32_t headDim_; + uint32_t blockSize_; + uint32_t eleCountPerS_; + uint32_t eleCountPerSSplit_; + uint32_t mm1RoundM_; + uint32_t eleCountPerHead_; + uint32_t eleCountPerBlock_; + uint32_t eleCountPerOinterSplit_; + uint32_t eleCountPerKVCache_; + uint32_t currentCoreId_; + uint32_t maskMaxSize_; + uint32_t actualUsedAivNum_; + const uint16_t *blockCountPerBatch_; + const uint16_t *tailBlockSize_; + const uint16_t *headStart_; + const uint16_t *headEnd_; + uint32_t eleCountOFinal_; + + AscendC::TQue maskQueue_; + AscendC::TQue decayQueue_; + AscendC::TQue attentionOutQueue_; + AscendC::TQue pOutQueue_; + AscendC::TBuf kvCacheBuf_; + AscendC::TBuf kPrimeBuf_; + AscendC::TBuf<> castDataBuf_; + AscendC::TBuf<> decayHelp_; +}; + +template +__aicore__ inline void LightningAttentionPrefill::Init( + GM_ADDR query, GM_ADDR key, GM_ADDR value, GM_ADDR slope_rate, GM_ADDR kv_history, GM_ADDR attention_out, + GM_ADDR kv_caches, GM_ADDR workspace, const LightningAttentionPrefillTilingData *__restrict tiling, + AscendC::TPipe *pipe) +{ + currentCoreId_ = GetBlockIdx(); + tiling_ = tiling; + headNum_ = tiling->laBaseParams.headNum; + headDim_ = tiling->laBaseParams.headDim; + blockSize_ = tiling->laBaseParams.blockSize; + eleCountPerS_ = blockSize_ * blockSize_; + eleCountPerSSplit_ = tiling->mm1TilingData.baseM * tiling->mm1TilingData.baseN; + mm1RoundM_ = tiling->mm1TilingData.singleCoreM / tiling->mm1TilingData.baseM; + eleCountPerHead_ = tiling->laBaseParams.eleCountPerHead; + eleCountPerBlock_ = 
tiling->laBaseParams.eleCountPerBlock; + eleCountPerOinterSplit_ = tiling->mm3TilingData.baseM * tiling->mm3TilingData.baseN; + eleCountPerKVCache_ = tiling->laBaseParams.headDim * tiling->laBaseParams.headDim; + actualUsedAivNum_ = tiling->laBaseParams.actualUsedAivNum; + blockCountPerBatch_ = tiling->laBaseParams.blockCountPerBatch; + tailBlockSize_ = tiling->laBaseParams.tailBlockSize; + headStart_ = tiling->laBaseParams.headStart; + headEnd_ = tiling->laBaseParams.headEnd; + + queryGM_.SetGlobalBuffer((__gm__ T *)query); + keyGM_.SetGlobalBuffer((__gm__ T *)key); + valueGM_.SetGlobalBuffer((__gm__ T *)value); + slopeRateGM_.SetGlobalBuffer((__gm__ T *)slope_rate); + attentionOutGM_.SetGlobalBuffer((__gm__ T *)attention_out); + kvCacheHistoryGM_.SetGlobalBuffer((__gm__ T *)kv_history); + kvCacheOutGM_.SetGlobalBuffer((__gm__ T *)kv_caches); + InitWorkspace(workspace); + + auto maxBufSize = 128 * 128; + maskMaxSize_ = 64 * 64; + eleCountOFinal_ = eleCountPerOinterSplit_ < maskMaxSize_ ? 
eleCountPerOinterSplit_ : maskMaxSize_; + pipe->InitBuffer(pOutQueue_, 1, sizeof(float) * maxBufSize); // 64k + if constexpr (!IsSameType::value) { + pipe->InitBuffer(castDataBuf_, sizeof(T) * maxBufSize); // 32k + } + pipe->InitBuffer(maskQueue_, 1, sizeof(float) * maskMaxSize_); // 16k + pipe->InitBuffer(decayQueue_, 1, sizeof(float) * (blockSize_ + blockSize_ + 8)); // 3k + pipe->InitBuffer(attentionOutQueue_, 1, sizeof(T) * maskMaxSize_); // 8k + pipe->InitBuffer(decayHelp_, sizeof(float) * 3 * blockSize_); // 3k + pipe->InitBuffer(kvCacheBuf_, sizeof(float) * maxBufSize); // 64k + pipe->InitBuffer(kPrimeBuf_, sizeof(T) * headDim_); // 0.25k + + InitMask(); +} + +template +__aicore__ inline void LightningAttentionPrefill::InitWorkspace(GM_ADDR workspace) +{ + // workspace reserved for each core + uint32_t pWorkSpaceSize = eleCountPerS_ * sizeof(T), oIntraWorkSpaceSize = eleCountPerBlock_ * sizeof(float), + updatedKeyWorkSpaceSize = eleCountPerBlock_ * sizeof(float), + baseWorkspaceOffset = (pWorkSpaceSize + oIntraWorkSpaceSize + updatedKeyWorkSpaceSize) * currentCoreId_; + outPWorkspaceGM_.SetGlobalBuffer((__gm__ T *)(workspace + baseWorkspaceOffset), eleCountPerS_); + outIntraWorkspaceGM_.SetGlobalBuffer((__gm__ float *)(workspace + baseWorkspaceOffset + pWorkSpaceSize), + eleCountPerBlock_); + auto updatedKeyOffset = baseWorkspaceOffset + pWorkSpaceSize + oIntraWorkSpaceSize; + updatedKeyGM_.SetGlobalBuffer((__gm__ T *)(workspace + updatedKeyOffset), eleCountPerBlock_); + // reuse same workspace + oInterWorkspaceGM_.SetGlobalBuffer((__gm__ float *)(workspace + updatedKeyOffset), eleCountPerBlock_); + + // workspace shared by each core + uint32_t sharedWorkspaceOffset = + (pWorkSpaceSize + oIntraWorkSpaceSize + updatedKeyWorkSpaceSize) * actualUsedAivNum_; + maskGM_.SetGlobalBuffer((__gm__ float *)(workspace + sharedWorkspaceOffset)); + decayGM_.SetGlobalBuffer( + (__gm__ float *)(workspace + sharedWorkspaceOffset + headNum_ * blockSize_ * blockSize_ * 
sizeof(float))); +} + +template +__aicore__ inline void LightningAttentionPrefill::InitMask() +{ + uint32_t curHeadIdx = currentCoreId_; + float s; + + while (curHeadIdx < headNum_) { + s = GetSlope(curHeadIdx); + GenerateMask(curHeadIdx, s); + GenerateDecay(curHeadIdx, s); + curHeadIdx += actualUsedAivNum_; + } + AscendC::SyncAll(); +} + +template +__aicore__ inline void LightningAttentionPrefill::GenerateMask(uint32_t headIdx, float s) +{ + // use kvCacheBuf (128, 128) as help tensor to generate mask then copy to GM + // blockSize greater than 128 shall iterate multiple times to get all mask values + auto helpTensor = kvCacheBuf_.Get(); + + uint32_t mOffset; + uint32_t tmp = 0xFF800000; // -inf + int32_t eleCountPerMask = blockSize_ * blockSize_; + int32_t eleCountPerHelp = 128 * 128; + int32_t iterTimes = (eleCountPerMask + eleCountPerHelp - 1) / (eleCountPerHelp); + int32_t blocksPerIter = blockSize_ / iterTimes; + int32_t eleCountPerIter = blocksPerIter * blockSize_; + + int32_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3)); + int32_t eventIdMte3ToV = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_V)); + int32_t eventIdMte3ToS = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_S)); + + for (int32_t iter = 0; iter < iterTimes; ++iter) { + AscendC::PipeBarrier(); + AscendC::Duplicate(helpTensor, *((float *)&tmp), eleCountPerIter); + AscendC::PipeBarrier(); + + for (int32_t b = 0; b < blocksPerIter; ++b) { + AscendC::PipeBarrier(); + AscendC::CreateVecIndex(helpTensor[b * blockSize_], (float)-(iter * blocksPerIter + b), + (iter * blocksPerIter + b) + 1); + AscendC::PipeBarrier(); + + // CreateVecIndex() will pad to 8 float elems, set padded values back to -inf + for (int32_t i = b + 1; i < 8; ++i) { + helpTensor.SetValue(iter * blocksPerIter + b * blockSize_ + i, *((float *)&tmp)); + } + } + + AscendC::PipeBarrier(); + AscendC::Muls(helpTensor, helpTensor, s, eleCountPerIter); + 
AscendC::PipeBarrier(); + AscendC::Exp(helpTensor, helpTensor, eleCountPerIter); + AscendC::PipeBarrier(); + + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + mOffset = headIdx * blockSize_ * blockSize_ + iter * eleCountPerIter; + AscendC::DataCopy(maskGM_[mOffset], helpTensor, eleCountPerIter); + SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); + SetFlag(eventIdMte3ToS); + WaitFlag(eventIdMte3ToS); + } +} + +template +__aicore__ inline void LightningAttentionPrefill::Process() +{ + float s; + uint16_t absoluteHeadIdx = headStart_[currentCoreId_]; + uint32_t offset = absoluteHeadIdx * eleCountPerHead_, blockOffset; + uint32_t kvCacheOffset = absoluteHeadIdx * eleCountPerKVCache_; + bool isFirstLoop = true; + for (uint16_t relativeHeadIdx, batchId; absoluteHeadIdx <= headEnd_[currentCoreId_]; + ++absoluteHeadIdx, offset += eleCountPerHead_, kvCacheOffset += eleCountPerKVCache_) { + if (isFirstLoop) { + isFirstLoop = false; + } else { + WaitKVCacheSaved(); + } + InitKVCache(kvCacheOffset); + relativeHeadIdx = absoluteHeadIdx % headNum_; + batchId = absoluteHeadIdx / headNum_; + CopyDecayIn(relativeHeadIdx); + auto decayTensor = decayQueue_.DeQue(); + + blockOffset = offset; + for (uint32_t blockIdx = 0; blockIdx + 1 < blockCountPerBatch_[batchId]; ++blockIdx) { + ComputeEachBlock(blockOffset, relativeHeadIdx, decayTensor); + blockOffset += eleCountPerBlock_; + } + if (tailBlockSize_[batchId] != 0) { + GenerateTailBlockDecay(decayTensor, batchId, GetSlope(relativeHeadIdx)); + } + ComputeEachBlock(blockOffset, relativeHeadIdx, decayTensor); + decayQueue_.FreeTensor(decayTensor); + SaveKVCache(kvCacheOffset); + } +} + +template +__aicore__ inline void LightningAttentionPrefill::CopyDecayIn(uint32_t headIdx) +{ + auto decayLocal = decayQueue_.AllocTensor(); + uint32_t decayOffset = headIdx * (blockSize_ + blockSize_ + 8); + AscendC::DataCopy(decayLocal, decayGM_[decayOffset], blockSize_ + blockSize_ + 8); + decayQueue_.EnQue(decayLocal); +} + +template 
+__aicore__ inline float LightningAttentionPrefill::GetSlope(uint32_t headIdx) +{ + float s; + if constexpr (AscendC::IsSameType::value) { + s = AscendC::ToFloat(slopeRateGM_.GetValue(headIdx)); + } else { + s = (float)slopeRateGM_.GetValue(headIdx); + } + return s; +} + +template +__aicore__ inline void LightningAttentionPrefill::GenerateDecay(uint32_t headIdx, float s) +{ + auto helpTensor = decayHelp_.Get(); + + // q_decay + auto qDecayTensor = helpTensor; + AscendC::PipeBarrier(); + AscendC::CreateVecIndex(qDecayTensor, (float)1, blockSize_); + AscendC::PipeBarrier(); + AscendC::Muls(qDecayTensor, qDecayTensor, -s, blockSize_); + AscendC::PipeBarrier(); + AscendC::Exp(qDecayTensor, qDecayTensor, blockSize_); + AscendC::PipeBarrier(); + + // k_decay + auto kDecayTensor = helpTensor[blockSize_]; + AscendC::PipeBarrier(); + AscendC::CreateVecIndex(kDecayTensor, (float)1, blockSize_); + AscendC::PipeBarrier(); + AscendC::Muls(kDecayTensor, kDecayTensor, (float)-1, blockSize_); + AscendC::PipeBarrier(); + AscendC::Adds(kDecayTensor, kDecayTensor, (float)(int32_t)blockSize_, blockSize_); + AscendC::PipeBarrier(); + AscendC::Muls(kDecayTensor, kDecayTensor, -s, blockSize_); + AscendC::PipeBarrier(); + AscendC::Exp(kDecayTensor, kDecayTensor, blockSize_); + AscendC::PipeBarrier(); + + // block_decay + auto blockDecayTensor = helpTensor[blockSize_ + blockSize_]; + AscendC::PipeBarrier(); + AscendC::Duplicate(blockDecayTensor, -s * blockSize_, 8); + AscendC::PipeBarrier(); + AscendC::Exp(blockDecayTensor, blockDecayTensor, 8); + AscendC::PipeBarrier(); + + uint32_t decayOffset = headIdx * (blockSize_ + blockSize_ + 8); + int32_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3)); + int32_t eventIdMte3ToV = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_V)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + AscendC::DataCopy(decayGM_[decayOffset], helpTensor, blockSize_ + blockSize_ + 8); + 
SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); +} + +template +__aicore__ inline void LightningAttentionPrefill::GenerateTailBlockDecay( + const AscendC::LocalTensor &decayTensor, uint32_t batchId, float s) +{ + auto currentTailBlockSize = (int32_t)tailBlockSize_[batchId]; + int32_t tailSizePad = (currentTailBlockSize + 8 - 1) / 8 * 8; + auto kDecayTensor = decayTensor[blockSize_]; + AscendC::Duplicate(kDecayTensor, (float)0, blockSize_); + AscendC::PipeBarrier(); + AscendC::CreateVecIndex(kDecayTensor, (float)1, tailSizePad); + AscendC::PipeBarrier(); + AscendC::Muls(kDecayTensor, kDecayTensor, (float)-1, tailSizePad); + AscendC::PipeBarrier(); + AscendC::Adds(kDecayTensor, kDecayTensor, (float)currentTailBlockSize, tailSizePad); + AscendC::PipeBarrier(); + AscendC::Muls(kDecayTensor, kDecayTensor, -s, tailSizePad); + AscendC::PipeBarrier(); + AscendC::Exp(kDecayTensor, kDecayTensor, tailSizePad); + AscendC::PipeBarrier(); + + auto blockDecayTensor = decayTensor[blockSize_ + blockSize_]; + AscendC::Duplicate(blockDecayTensor, -s * currentTailBlockSize, 8); + AscendC::PipeBarrier(); + AscendC::Exp(blockDecayTensor, blockDecayTensor, 8); + AscendC::PipeBarrier(); +} + +template +__aicore__ inline void LightningAttentionPrefill::ComputeEachBlock(uint32_t offset, uint32_t headIdx, + const AscendC::LocalTensor &decayTensor) +{ + auto qDecayTensor = decayTensor; + auto kDecayTensor = decayTensor[blockSize_]; + auto blockDecayTensor = decayTensor[blockSize_ + blockSize_]; + + ComputeOIntra(offset, headIdx); + ComputeOInter(offset, qDecayTensor); + UpdateKVCache(offset, kDecayTensor, blockDecayTensor); +} + +template +__aicore__ inline void LightningAttentionPrefill::ComputeOIntra(uint32_t offset, uint32_t headIdx) +{ + // Step 1: calculate S = matmul(Q, K) + mm1.SetTensorA(queryGM_[offset]); + mm1.SetTensorB(keyGM_[offset], true); + for (uint32_t computeRound = 0; mm1.template Iterate(); ++computeRound) { + auto pOutTensor = pOutQueue_.AllocTensor(); + mm1.template 
GetTensorC(pOutTensor, false, true); + // Step 2: calculate P = mul(S, M) + ComputePSplit(headIdx, computeRound, pOutTensor); + CopyPOut(computeRound); + } + mm1.End(); + + // Step 3: calculate Ointra = matmul(P, V) + mm2.SetTensorA(outPWorkspaceGM_); + mm2.SetTensorB(valueGM_[offset]); + mm2.template IterateAll(outIntraWorkspaceGM_); + mm2.End(); +} + +template +__aicore__ inline void LightningAttentionPrefill::ComputeOInter(uint32_t offset, + const AscendC::LocalTensor &qDecayTensor) +{ + float qDecay; + uint32_t mm3BaseM = tiling_->mm3TilingData.baseM; + // Step 1: calculate O_inter = matmul(Q, KV) + auto kvCacheTensor = kvCacheBuf_.Get(); + mm3.SetWorkspace(oInterWorkspaceGM_); + mm3.SetTensorA(queryGM_[offset]); + if constexpr (IsSameType::value) { + mm3.SetTensorB(kvCacheTensor); + } else { + auto kvCastBuf = castDataBuf_.Get(); + AscendC::Cast(kvCastBuf, kvCacheTensor, RoundMode::CAST_ROUND, eleCountPerKVCache_); + mm3.SetTensorB(kvCastBuf); + } + mm3.template Iterate(); + for (uint32_t computeRound = 0, attentionBaseOffset = 0, totalRound = blockSize_ / mm3BaseM; + computeRound < totalRound; ++computeRound, attentionBaseOffset += eleCountPerOinterSplit_) { + auto oInterTensor = pOutQueue_.AllocTensor(); + mm3.template GetTensorC(oInterTensor, false, true); + // headDim <= 128, which means only M will split, N will not split + // Step 2: update O_inter with decay + for (uint32_t b = 0; b < mm3BaseM; b++) { + qDecay = qDecayTensor.GetValue(computeRound * mm3BaseM + b); + AscendC::PipeBarrier(); + AscendC::Muls(oInterTensor[b * headDim_], oInterTensor[b * headDim_], qDecay, headDim_); + AscendC::PipeBarrier(); + } + + for (uint32_t attentionRelativeOffset = 0; attentionRelativeOffset < eleCountPerOinterSplit_; + attentionRelativeOffset += eleCountOFinal_) { + CopyOIntraIn(attentionBaseOffset + attentionRelativeOffset); + // Step 3: Add O_inter and Cast + CalculateOFinal(oInterTensor, attentionRelativeOffset); + // Step 4: Save to O + CopyAttentionOut(offset + 
attentionBaseOffset + attentionRelativeOffset); + } + pOutQueue_.FreeTensor(oInterTensor); + } + mm3.End(); +} + +template +__aicore__ inline void LightningAttentionPrefill::InitKVCache(uint32_t kvCacheOffset) +{ + auto kvCacheTensor = kvCacheBuf_.Get(); + if (kvCacheHistoryGM_.GetPhyAddr() == nullptr) { + AscendC::Duplicate(kvCacheTensor, 0.0f, eleCountPerKVCache_); + } else { + if constexpr (IsSameType::value) { + AscendC::DataCopy(kvCacheTensor, kvCacheHistoryGM_[kvCacheOffset], eleCountPerKVCache_); + } else { + auto tmpBuf = kvCacheTensor[eleCountPerKVCache_ / 2].ReinterpretCast(); + AscendC::DataCopy(tmpBuf, kvCacheHistoryGM_[kvCacheOffset], eleCountPerKVCache_); + int32_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + AscendC::Cast(kvCacheTensor, tmpBuf, RoundMode::CAST_NONE, eleCountPerKVCache_); + } + } +} + + +template +__aicore__ inline void LightningAttentionPrefill::UpdateKVCache(uint32_t offset, + const AscendC::LocalTensor &kDecayTensor, + const AscendC::LocalTensor &blockDecayTensor) +{ + // Step 1: update & save K + auto kTensor = kPrimeBuf_.Get(); + int32_t eventIdMte3ToMte2 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_MTE2)); + int32_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE2_V)); + int32_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3)); + for (uint32_t b = 0; b < blockSize_; b++) { + float kDecay = kDecayTensor.GetValue(b); + AscendC::PipeBarrier(); + AscendC::DataCopy(kTensor, keyGM_[offset + b * headDim_], headDim_); + AscendC::PipeBarrier(); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + if constexpr (IsSameType::value) { + AscendC::Muls(kTensor, kTensor, kDecay, headDim_); + } else { + auto tmpBuf = castDataBuf_.Get(); + AscendC::Cast(tmpBuf, kTensor, RoundMode::CAST_NONE, headDim_); + AscendC::PipeBarrier(); + AscendC::Muls(tmpBuf, 
tmpBuf, kDecay, headDim_); + AscendC::PipeBarrier(); + AscendC::Cast(kTensor, tmpBuf, RoundMode::CAST_ROUND, headDim_); + } + + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + AscendC::DataCopy(updatedKeyGM_[b * headDim_], kTensor, headDim_); + SetFlag(eventIdMte3ToMte2); + WaitFlag(eventIdMte3ToMte2); + } + + // Step 2: update KV + auto kvCacheTensor = kvCacheBuf_.Get(); + float blockDecay = blockDecayTensor.GetValue(0); + AscendC::PipeBarrier(); + AscendC::Muls(kvCacheTensor, kvCacheTensor, blockDecay, eleCountPerKVCache_); + + // Step 3: calculate KV_cur = matmul(Kt, V) + mm4.SetTensorA(updatedKeyGM_, true); + mm4.SetTensorB(valueGM_[offset]); + // KV_cur shape is (headDim, headDim), which means matmul will finish in one round + if (mm4.template Iterate()) { + // Step 4: Add KV_cur + auto kvCurrentTensor = pOutQueue_.AllocTensor(); + mm4.template GetTensorC(kvCurrentTensor, false, true); + AscendC::PipeBarrier(); + AscendC::Add(kvCacheTensor, kvCacheTensor, kvCurrentTensor, eleCountPerKVCache_); + pOutQueue_.FreeTensor(kvCurrentTensor); + } + mm4.End(); +} + +template +__aicore__ inline void LightningAttentionPrefill::SaveKVCache(uint32_t kvCacheOffset) +{ + auto kvCacheTensor = kvCacheBuf_.Get(); + if constexpr (IsSameType::value) { + int32_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + AscendC::DataCopy(kvCacheOutGM_[kvCacheOffset], kvCacheTensor, eleCountPerKVCache_); + } else { + auto tmpBuf = castDataBuf_.Get(); + int32_t eventIdMte3ToV = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_V)); + SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); + AscendC::Cast(tmpBuf, kvCacheTensor, RoundMode::CAST_ROUND, eleCountPerKVCache_); + int32_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + AscendC::DataCopy(kvCacheOutGM_[kvCacheOffset], tmpBuf, 
eleCountPerKVCache_); + } +} + +template +__aicore__ inline void LightningAttentionPrefill::WaitKVCacheSaved() +{ + int32_t eventIdVToMte2 = static_cast(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE2)); + SetFlag(eventIdVToMte2); + WaitFlag(eventIdVToMte2); +} + +template +__aicore__ inline void LightningAttentionPrefill::CopyOIntraIn(uint32_t attentionOffset) +{ + auto oIntraTensor = maskQueue_.AllocTensor(); + AscendC::DataCopy(oIntraTensor, outIntraWorkspaceGM_[attentionOffset], eleCountOFinal_); + maskQueue_.EnQue(oIntraTensor); +} + +template +__aicore__ inline void LightningAttentionPrefill::CalculateOFinal(const AscendC::LocalTensor &oInterTensor, + uint32_t attentionOffset) +{ + auto oIntraTensor = maskQueue_.DeQue(); + auto oFinalTensor = attentionOutQueue_.AllocTensor(); + if constexpr (IsSameType::value) { + AscendC::Add(oFinalTensor, oIntraTensor, oInterTensor[attentionOffset], eleCountOFinal_); + } else { + AscendC::Add(oIntraTensor, oIntraTensor, oInterTensor[attentionOffset], eleCountOFinal_); + AscendC::PipeBarrier(); + AscendC::Cast(oFinalTensor, oIntraTensor, RoundMode::CAST_ROUND, eleCountOFinal_); + } + + maskQueue_.FreeTensor(oIntraTensor); + attentionOutQueue_.EnQue(oFinalTensor); +} + +template +__aicore__ inline void LightningAttentionPrefill::CopyAttentionOut(uint32_t attentionOffset) +{ + auto oFinalTensor = attentionOutQueue_.DeQue(); + AscendC::DataCopy(attentionOutGM_[attentionOffset], oFinalTensor, eleCountOFinal_); + attentionOutQueue_.FreeTensor(oFinalTensor); +} + +template +__aicore__ inline void LightningAttentionPrefill::CopyMIn(uint32_t headIdx, uint32_t maskOffset, uint32_t copyRows) +{ + auto maskLocal = maskQueue_.AllocTensor(); + AscendC::PipeBarrier(); + AscendC::DataCopyParams copyParams{ + (uint16_t)copyRows, (uint16_t)(tiling_->mm1TilingData.baseN * sizeof(float) / DEFAULT_C0_SIZE), + (uint16_t)((tiling_->mm1TilingData.N - tiling_->mm1TilingData.baseN) * sizeof(float) / DEFAULT_C0_SIZE), + (uint16_t)0}; 
+ AscendC::DataCopy(maskLocal, maskGM_[maskOffset], copyParams); + maskQueue_.EnQue(maskLocal); +} + +template +__aicore__ inline void LightningAttentionPrefill::ComputePSplit(uint32_t headIdx, uint32_t computeRound, + const AscendC::LocalTensor &pOutTensor) +{ + uint32_t multiplyRound = (eleCountPerSSplit_ + maskMaxSize_ - 1) / maskMaxSize_, + multiplyEleCount = eleCountPerSSplit_ < maskMaxSize_ ? eleCountPerSSplit_ : maskMaxSize_, + copyRows = multiplyEleCount / tiling_->mm1TilingData.baseN; + uint32_t maskOffset = headIdx * blockSize_ * blockSize_, pOutTensorOffset = 0; + maskOffset += computeRound % mm1RoundM_ * tiling_->mm1TilingData.baseM * tiling_->mm1TilingData.N + + computeRound / mm1RoundM_ * tiling_->mm1TilingData.baseN; + for (uint32_t multiplyRoundIdx = 0; multiplyRoundIdx < multiplyRound; + ++multiplyRoundIdx, pOutTensorOffset += multiplyEleCount, maskOffset += copyRows * tiling_->mm1TilingData.N) { + CopyMIn(headIdx, maskOffset, copyRows); + auto maskLocal = maskQueue_.DeQue(); + AscendC::Mul(pOutTensor[pOutTensorOffset], pOutTensor[pOutTensorOffset], maskLocal, multiplyEleCount); + maskQueue_.FreeTensor(maskLocal); + } + if constexpr (!IsSameType::value) { + auto tempOutTensor = castDataBuf_.Get(); + AscendC::Cast(tempOutTensor, pOutTensor, RoundMode::CAST_ROUND, eleCountPerSSplit_); + } + pOutQueue_.EnQue(pOutTensor); +} + +template +__aicore__ inline void LightningAttentionPrefill::CopyPOut(uint32_t computeRound) +{ + auto pOutTensor = pOutQueue_.DeQue(); + uint32_t offset = computeRound % mm1RoundM_ * tiling_->mm1TilingData.baseM * tiling_->mm1TilingData.N + + computeRound / mm1RoundM_ * tiling_->mm1TilingData.baseN; + AscendC::DataCopyParams copyParams{ + (uint16_t)tiling_->mm1TilingData.baseM, (uint16_t)(tiling_->mm1TilingData.baseN * sizeof(T) / DEFAULT_C0_SIZE), + (uint16_t)0, + (uint16_t)((tiling_->mm1TilingData.N - tiling_->mm1TilingData.baseN) * sizeof(T) / DEFAULT_C0_SIZE)}; + if constexpr (IsSameType::value) { + 
AscendC::DataCopy(outPWorkspaceGM_[offset], pOutTensor, copyParams); + } else { + auto tempOutTensor = castDataBuf_.Get(); + AscendC::DataCopy(outPWorkspaceGM_[offset], tempOutTensor, copyParams); + } + pOutQueue_.FreeTensor(pOutTensor); +} + +} // namespace LightningAttention + +#endif //LIGHTNING_ATTENTION_PREFILL_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 959f10b3b3e..8803d1c2e98 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -43,6 +43,8 @@ #include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h" #include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h" #include "lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h" +#include "lightning_attention_decode/lightning_attention_decode_torch_adpt.h" +#include "lightning_attention_prefill/lightning_attention_prefill_torch_adpt.h" #include #include #include @@ -927,4 +929,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " int sparse_count=2048, int sparse_mode=3) -> Tensor" ); ops.impl("npu_lightning_indexer_quant", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer_quant); + + // lightning_attention_prefill + ops.def( + "npu_lightning_attention_prefill(Tensor query, Tensor key, Tensor value, Tensor slope_rate, " + " int block_size, Tensor? kv_history=None, int[]? 
actual_seq_len=None) -> (Tensor, Tensor)" + ); + ops.impl("npu_lightning_attention_prefill", torch::kPrivateUse1, &vllm_ascend::npu_lightning_attention_prefill); + + // lightning_attention_decode + ops.def( + "npu_lightning_attention_decode(Tensor query, Tensor key, Tensor value, Tensor kv_caches_ref, " + " Tensor slope_rate, Tensor slot_ids) -> Tensor" + ); + ops.impl("npu_lightning_attention_decode", torch::kPrivateUse1, &vllm_ascend::npu_lightning_attention_decode); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index ed7f7dc8b6c..d864a32a3de 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -569,6 +569,39 @@ at::Tensor npu_lightning_indexer_quant_meta( return lightning_indexer_quant_output; } +at::Tensor npu_lightning_attention_decode_meta( + const at::Tensor &query, + const at::Tensor &key, + const at::Tensor &value, + const at::Tensor &kv_caches_ref, + const at::Tensor &slope_rate, + const at::Tensor &slot_ids) +{ + auto output_size_0 = {query.size(0), query.size(1) * query.size(3)}; + auto output_dtype_0 = query.scalar_type(); + at::Tensor attention_out = at::empty(output_size_0, query.options().dtype(output_dtype_0)); + return attention_out; +} + +std::tuple npu_lightning_attention_prefill_meta( + const at::Tensor &query, + const at::Tensor &key, + const at::Tensor &value, + const at::Tensor &slope_rate, + int64_t block_size, + const c10::optional &kv_history, + at::OptionalIntArrayRef actual_seq_len) +{ + auto default_seq_len = std::vector(query.size(0), query.size(2)); + auto actual_seq_len_value = actual_seq_len.value_or(default_seq_len); + auto output_size_0 = {query.size(0), query.size(1), query.size(2), query.size(3)}; + auto output_size_1 = {query.size(0), query.size(1), query.size(3), query.size(3)}; + auto output_dtype_0 = query.scalar_type(); + at::Tensor attention_out = at::empty(output_size_0, query.options().dtype(output_dtype_0)); + at::Tensor kv_caches_out = at::empty(output_size_1, 
query.options().dtype(output_dtype_0)); + return std::tuple(attention_out, kv_caches_out); +} + } // namespace meta } // namespace vllm_ascend @@ -618,5 +651,9 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta); // Lightning indexer quant ops.impl("npu_lightning_indexer_quant", &vllm_ascend::meta::npu_lightning_indexer_quant_meta); + // lightning_attention_decode + ops.impl("npu_lightning_attention_decode", &vllm_ascend::meta::npu_lightning_attention_decode_meta); + // lightning_attention_prefill + ops.impl("npu_lightning_attention_prefill", &vllm_ascend::meta::npu_lightning_attention_prefill_meta); } } diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_lightning_attention_decode.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_lightning_attention_decode.py new file mode 100644 index 00000000000..6ca627c97b4 --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_lightning_attention_decode.py @@ -0,0 +1,124 @@ +import gc +import math +import copy +import torch +import torch_npu + +# enable vllm-ascend custom ops +from vllm_ascend.utils import enable_custom_op +enable_custom_op() + + +def build_decay(head_num): + # return decay rate with shape (head_num) + start = 2 ** (-(2 ** -(math.log2(head_num) - 3))) + ratio = start + return torch.tensor([start * ratio**i for i in range(head_num)]) + + +def lightning_attention_decode(q, k, v, kv_decay, kv_cache, dtype): + kv_cur = torch.outer(k, v) + kv_pre = kv_decay * kv_cache + kv = kv_cur + kv_pre + o = torch.matmul(q, kv) + return o, kv + + +def reference_lightning_attention_decode(query, key, value, slope_rate, kv_history, slot_ids, dtype): + # in_tensors[0]: Query (batch, head, 1, d) + # in_tensors[1]: Key (batch, head, 1, d) + # in_tensors[2]: Value (batch, head, 1, d) + # in_tensors[3]: Decay (head) + # in_tensors[4]: KV Caches (batch, head, d, d) + # in_tensors[5]: slot_ids (batch) + 
batch_num, head_num, _, d = query.shape + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + slope_rate = slope_rate.to(torch.float32) + kv_caches = kv_history.clone().to(torch.float32) + + # initialize O (batch, head * d) + output = torch.zeros(batch_num, head_num * d, dtype=dtype) + + for batchidx in range(batch_num): + slot_id = slot_ids[batchidx] + for headidx in range(head_num): + q = query[batchidx, headidx, 0, :] + k = key[batchidx, headidx, 0, :] + v = value[batchidx, headidx, 0, :] + kv_decay = math.exp(-slope_rate[headidx]) + kv_cache = kv_caches[slot_id, headidx, :, :] + o, kv = lightning_attention_decode(q, k, v, kv_decay, kv_cache, dtype) + output[batchidx, headidx*d:(headidx+1)*d] = o.to(dtype) + kv_caches[slot_id, headidx, :, :] = kv + + return output, kv_caches.to(dtype) + + +def execute_lightning_attention_decode_case(q_batch_size, kv_cache_batch, head_num, head_dim, + dtype=torch.float16): + query_cpu = torch.randn(q_batch_size, head_num, 1, head_dim).to(dtype) + key_cpu = torch.randn(q_batch_size, head_num, 1, head_dim).to(dtype) + value_cpu = torch.randn(q_batch_size, head_num, 1, head_dim).to(dtype) + slope_rate_cpu = build_decay(head_num).to(dtype) + kv_history_cpu = torch.randn(kv_cache_batch, head_num, head_dim, head_dim).to(dtype) + slot_ids_cpu = torch.arange(kv_cache_batch).to(torch.int32)[-q_batch_size:] + + query_npu = copy.deepcopy(query_cpu).npu() + key_npu = copy.deepcopy(key_cpu).npu() + value_npu = copy.deepcopy(value_cpu).npu() + slope_rate_npu = copy.deepcopy(slope_rate_cpu).npu() + kv_history_npu = copy.deepcopy(kv_history_cpu).npu() + slot_ids_npu = copy.deepcopy(slot_ids_cpu).npu() + + + # calculate on npu + attention_npu_out = torch.ops._C_ascend.npu_lightning_attention_decode( + query_npu, key_npu, value_npu, kv_history_npu, slope_rate_npu, slot_ids_npu) + + # calculate on cpu + attention_cpu_out, kv_cache_cpu_out = reference_lightning_attention_decode( + query_cpu, key_cpu, 
value_cpu, slope_rate_cpu, kv_history_cpu, slot_ids_cpu, dtype) + + # compare result + torch.testing.assert_close(attention_npu_out.cpu(), + attention_cpu_out, + atol=1e-9, + rtol=1e-6) + torch.testing.assert_close(kv_history_npu.cpu(), + kv_cache_cpu_out, + atol=1e-9, + rtol=1e-6) + + +@torch.inference_mode() +def test_lightning_attention_decode_same_batch(): + q_batch_size = 256 + head_num = 8 + head_dim = 128 + execute_lightning_attention_decode_case(q_batch_size, q_batch_size, head_num, head_dim) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test_lightning_attention_decode_different_batch(): + q_batch_size = 1 + kv_cache_batch = 256 + head_num = 8 + head_dim = 128 + execute_lightning_attention_decode_case(q_batch_size, kv_cache_batch, head_num, head_dim) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test_lightning_attention_decode_fp32(): + q_batch_size = 100 + head_num = 16 + head_dim = 128 + execute_lightning_attention_decode_case(q_batch_size, q_batch_size, head_num, head_dim, torch.float32) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_lightning_attention_prefill.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_lightning_attention_prefill.py new file mode 100644 index 00000000000..cba0b5ea206 --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_lightning_attention_prefill.py @@ -0,0 +1,255 @@ +import gc +import math +import copy +import numpy as np +import torch +import torch_npu + +# enable vllm-ascend custom ops +from vllm_ascend.utils import enable_custom_op +enable_custom_op() + + +def build_decay(head_num): + # return decay rate with shape (head_num) + start = 2 ** (-(2 ** -(math.log2(head_num) - 3))) + ratio = start + return torch.tensor([start * ratio**i for i in range(head_num)]) + + +def 
lightning_attention_prefill(qt, kt, vt, kvsum, diag_decay, q_decay, block_decay, k_decay, dtype): + # O_intra = [(Q_t K_t^T) * M]V_t + qt_kt = torch.matmul(qt, torch.transpose(kt, 0, 1)) + qt_kt_mask = torch.mul(qt_kt, diag_decay).to(dtype) + o_intra = torch.matmul(qt_kt_mask.to(torch.float32), vt) + + # O_inter = Λ Q_t (KV) + o_inter = q_decay * torch.matmul(qt, kvsum.to(dtype).to(torch.float32)) + + # update KVsum + # KVsum = λ^B KVsum + (λ^B Λ^-1 K_t)^T V_t + kt = k_decay * kt + kt = kt.to(dtype) + kt_vt = torch.matmul(torch.transpose(kt, 0, 1).to(torch.float32), vt) + kvsum = torch.add(block_decay * kvsum, kt_vt) + + # O_t = O_intra + O_inter + o_t = torch.add(o_intra, o_inter) + + return o_t, kvsum + + +def reference_lightning_attention(q, k, v, ed, block_size, kv_history, seq_len): + dtype = q.dtype + batch_num, head_num, n, d = q.shape + if seq_len is None: + seq_len = [n] * batch_num + B = block_size + T = n // B + + # get Q, K, V, decay + # in_tensors[0]: Query without tiling (batch, head, n, d) + # in_tensors[1]: Key without tiing (batch, head, n, d) + # in_tensors[2]: Value without tiling (batch, head, n, d) + # in_tensors[3]: Decay (head) + query = q.reshape(batch_num, head_num, T, B, d).to(torch.float32) # (batch, head, T, B, d) + key = k.reshape(batch_num, head_num, T, B, d).to(torch.float32) # (batch, head, T, B, d) + value = v.reshape(batch_num, head_num, T, B, d).to(torch.float32) # (batch, head, T, B, d) + decay = ed.to(torch.float32) # (head) + + # initialize O, KVsum + output = torch.zeros(batch_num, head_num, T, B, d, dtype=dtype) # (batch, head, T, B, d) + if kv_history is None: + kvsums = torch.zeros(batch_num, head_num, d, d, dtype=torch.float32) + else: + kvsums = kv_history.clone().to(torch.float32) # (batch, head, d, d) + + for batchidx in range(batch_num): + for headidx in range(head_num): + kvsum = kvsums[batchidx, headidx, :, :] + + # diag_decay: M with shape (B, B) + # q_decay: Λ with shape (B, 1) + # block_decay: λ^B with shape (1) + 
# k_decay: λ^B Λ^-1 with shape (B, 1) + s = decay[headidx] + i = torch.arange(B).view(B, 1) + j = torch.arange(B) + index = i - j + diag_decay = torch.exp(s * torch.where(index>=0, -index, float('-inf'))) + q_decay = torch.exp(-s * (j + 1)).reshape(B, 1) + block_decay = math.exp(-s * B) + k_decay = torch.exp(-s * (B - i - 1)) + + block_count = (seq_len[batchidx] + B - 1) // B + tail_block_size = seq_len[batchidx] % B + for t in range(block_count): + qt = query[batchidx, headidx, t, :, :] + kt = key[batchidx, headidx, t, :, :] + vt = value[batchidx, headidx, t, :, :] + if tail_block_size != 0 and t + 1 == block_count: + e = tail_block_size - i - 1 + e[tail_block_size:] = 0 + k_decay = torch.exp(-s * e) + block_decay = math.exp(-s * tail_block_size) + o_t, kvsum = lightning_attention_prefill( + qt, kt, vt, kvsum, diag_decay, q_decay, block_decay, k_decay, dtype) + output[batchidx, headidx, t, :, :] = o_t.to(dtype) + + kvsums[batchidx, headidx, :, :] = kvsum + + output = output.reshape(batch_num, head_num, n, d) # (batch, head, n, d) + kvsums = kvsums.to(dtype) + return [output, kvsums] + + +def execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, + has_kv_history=False, actual_seq_len=None, dtype=torch.float16, + slope_rate=None): + + base = 0.1 + query_cpu = base * torch.randn(batch_size, head_num, max_seq_len, head_dim).to(dtype) + key_cpu = base * torch.randn(batch_size, head_num, max_seq_len, head_dim).to(dtype) + value_cpu = base * torch.randn(batch_size, head_num, max_seq_len, head_dim).to(dtype) + if actual_seq_len: + for b in range(batch_size): + if actual_seq_len[b] < max_seq_len: + query_cpu[b,:, actual_seq_len[b]:,:] = 0 + key_cpu[b,:, actual_seq_len[b]:,:] = 0 + value_cpu[b,:, actual_seq_len[b]:,:] = 0 + + slope_rate_cpu = slope_rate + if slope_rate_cpu is None: + slope_rate_cpu = build_decay(head_num).to(dtype) + + query_npu = copy.deepcopy(query_cpu).npu() + key_npu = copy.deepcopy(key_cpu).npu() + value_npu 
= copy.deepcopy(value_cpu).npu() + slope_rate_npu = copy.deepcopy(slope_rate_cpu).npu() + kv_history_cpu = None + kv_history_npu = None + if has_kv_history: + kv_history_cpu = base * torch.randn(batch_size, head_num, head_dim, head_dim).to(dtype) + kv_history_npu = copy.deepcopy(kv_history_cpu).npu() + + # calculate on npu + attention_npu_out, kv_cache_npu_out = torch.ops._C_ascend.npu_lightning_attention_prefill( + query_npu, key_npu, value_npu, slope_rate_npu, block_size, kv_history_npu, actual_seq_len) + + # calculate on cpu + attention_cpu_out, kv_cache_cpu_out = reference_lightning_attention( + query_cpu, key_cpu, value_cpu, slope_rate_cpu, block_size, kv_history_cpu, actual_seq_len) + + if actual_seq_len: + for b in range(batch_size): + if actual_seq_len[b] < max_seq_len: + # npu default value may not be 0 + attention_npu_out[b,:, actual_seq_len[b]:,:] = 0 + + # compare result + torch.testing.assert_close(attention_npu_out.cpu(), + attention_cpu_out, + atol=1e-3, + rtol=1e-3) + torch.testing.assert_close(kv_cache_npu_out.cpu(), + kv_cache_cpu_out, + atol=1e-3, + rtol=1e-3) + + +@torch.inference_mode() +def test_lightning_attention_prefill_pad(): + batch_size = 1 + head_num = 4 + max_seq_len = 8192 + head_dim = 128 + block_size = 128 + execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test_lightning_attention_prefill_unpad_1(): + batch_size = 1 + head_num = 8 + max_seq_len = 16 + block_size = 16 + head_dim = 128 + actual_seq_len = [5] + execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, False, + actual_seq_len) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() +def test_lightning_attention_prefill_unpad_2(): + batch_size = 4 + head_num = 8 + max_seq_len = 2048 + block_size = 128 + head_dim = 128 + actual_seq_len = 
[np.random.randint(1, max_seq_len / block_size + 1) * block_size + for _ in range(batch_size)] + execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, + False, actual_seq_len) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test_lightning_attention_prefill_unpad_3(): + batch_size = 3 + head_num = 8 + max_seq_len = 384 + block_size = 128 + head_dim = 128 + actual_seq_len = [351, 129, 384] + execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, False, + actual_seq_len) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test_lightning_attention_prefill_unpad_4(): + batch_size = 1 + head_num = 4 + max_seq_len = 256 + block_size = 256 + head_dim = 128 + actual_seq_len = [5] + slope_rate = torch.tensor([0.9170, 0.8409, 0.7711, 0.7071], dtype=torch.float16) + execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, False, + actual_seq_len, torch.float16, slope_rate) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test_lightning_attention_prefill_with_kv_history(): + batch_size = 4 + head_num = 8 + max_seq_len = 1024 + head_dim = 128 + block_size = 128 + actual_seq_len = [np.random.randint(1, max_seq_len / block_size + 1) * block_size + for _ in range(batch_size)] + execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, + True, actual_seq_len) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test_lightning_attention_prefill_fp32(): + batch_size = 1 + head_num = 16 + max_seq_len = 256 + head_dim = 128 + block_size = 128 + actual_seq_len = [130] + execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, + True, 
actual_seq_len, torch.float32) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats()