-
Notifications
You must be signed in to change notification settings - Fork 975
[WIP][Ops] Add AscendC Custom Op for Lightning Attention #7590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f247804
946d970
8f90f4d
38bfc20
e413fde
563f720
eabc966
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| /* | ||
| * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #ifndef LIGHTNING_ATTENTION_DECODE_TORCH_ADPT_H | ||
| #define LIGHTNING_ATTENTION_DECODE_TORCH_ADPT_H | ||
|
|
||
| namespace vllm_ascend { | ||
| at::Tensor npu_lightning_attention_decode( | ||
| const at::Tensor &query, | ||
| const at::Tensor &key, | ||
| const at::Tensor &value, | ||
| const at::Tensor &kv_caches_ref, | ||
| const at::Tensor &slope_rate, | ||
| const at::Tensor &slot_ids) | ||
| { | ||
| auto output_size_0 = {query.size(0), query.size(1) * query.size(3)}; | ||
| auto output_dtype_0 = query.scalar_type(); | ||
| at::Tensor attention_out = at::empty(output_size_0, query.options().dtype(output_dtype_0)); | ||
| EXEC_NPU_CMD( | ||
| aclnnLightningAttentionDecode, | ||
| query, | ||
| key, | ||
| value, | ||
| slope_rate, | ||
| kv_caches_ref, | ||
| slot_ids, | ||
| "BNSD", | ||
| attention_out); | ||
| return attention_out; | ||
| } | ||
| } | ||
|
|
||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| # Copyright (c) 2025 Huawei Technologies Co., Ltd. | ||
| # This file is a part of the CANN Open Software. | ||
| # Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). | ||
| # Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| # See LICENSE in the root of the software repository for the full text of the License. | ||
| # ====================================================================================================================== | ||
|
|
||
| add_ops_compile_options( | ||
| OP_NAME LightningAttentionDecode | ||
| OPTIONS --cce-auto-sync | ||
| -Wno-deprecated-declarations | ||
| -Werror | ||
| ) | ||
|
|
||
| target_sources(op_host_aclnnInner PRIVATE | ||
| lightning_attention_decode_def.cpp | ||
| ) | ||
|
|
||
| target_sources(opapi PRIVATE | ||
| aclnn_lightning_attention.cpp | ||
| ) | ||
|
|
||
| target_sources(optiling PRIVATE | ||
| lightning_attention_decode_tiling.cpp | ||
| ) | ||
|
|
||
| if (NOT BUILD_OPEN_PROJECT) | ||
| target_sources(opmaster_ct PRIVATE | ||
| lightning_attention_decode_tiling.cpp | ||
| ) | ||
| endif () | ||
|
|
||
| target_include_directories(optiling PRIVATE | ||
| ${CMAKE_CURRENT_SOURCE_DIR} | ||
| ) | ||
|
|
||
| target_sources(opsproto PRIVATE | ||
| lightning_attention_decode_proto.cpp | ||
| ) | ||
|
|
||
| install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_lightning_attention_decode.h | ||
| DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| /** | ||
| * Copyright (c) 2025 Huawei Technologies Co., Ltd. | ||
| * This file is a part of the CANN Open Software. | ||
| * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). | ||
| * Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| * See LICENSE in the root of the software repository for the full text of the License. | ||
| */ | ||
|
|
||
| #include "graph/types.h" | ||
| #include "aclnn_lightning_attention_decode.h" | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| extern aclnnStatus aclnnInnerLightningAttentionDecodeGetWorkspaceSize( | ||
| const aclTensor *query, | ||
| const aclTensor *key, | ||
| const aclTensor *value, | ||
| const aclTensor *slopeRate, | ||
| aclTensor *kvCachesRef, | ||
| const aclTensor *slotIds, | ||
| char *inputLayoutOptional, | ||
| const aclTensor *attentionOut, | ||
| uint64_t *workspaceSize, | ||
| aclOpExecutor **executor); | ||
|
|
||
| extern aclnnStatus aclnnInnerLightningAttentionDecode( | ||
| void *workspace, | ||
| uint64_t workspaceSize, | ||
| aclOpExecutor *executor, | ||
| aclrtStream stream); | ||
|
|
||
| aclnnStatus aclnnLightningAttentionDecodeGetWorkspaceSize( | ||
| const aclTensor *query, | ||
| const aclTensor *key, | ||
| const aclTensor *value, | ||
| const aclTensor *slopeRate, | ||
| aclTensor *kvCachesRef, | ||
| const aclTensor *slotIds, | ||
| char *inputLayoutOptional, | ||
| const aclTensor *attentionOut, | ||
| uint64_t *workspaceSize, | ||
| aclOpExecutor **executor) | ||
| { | ||
| aclnnStatus ret = aclnnInnerLightningAttentionDecodeGetWorkspaceSize( | ||
| query, key, value, slopeRate, kvCachesRef, slotIds, | ||
| inputLayoutOptional, attentionOut, workspaceSize, executor); | ||
| return ret; | ||
| } | ||
|
|
||
| aclnnStatus aclnnLightningAttentionDecode( | ||
| void *workspace, | ||
| uint64_t workspaceSize, | ||
| aclOpExecutor *executor, | ||
| aclrtStream stream) | ||
| { | ||
| aclnnStatus ret = aclnnInnerLightningAttentionDecode(workspace, workspaceSize, executor, stream); | ||
| return ret; | ||
| } | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| /** | ||
| * Copyright (c) 2025 Huawei Technologies Co., Ltd. | ||
| * This file is a part of the CANN Open Software. | ||
| * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). | ||
| * Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| * See LICENSE in the root of the software repository for the full text of the License. | ||
| */ | ||
|
|
||
| #ifndef ACLNN_LIGHTNING_ATTENTION_DECODE_H_ | ||
| #define ACLNN_LIGHTNING_ATTENTION_DECODE_H_ | ||
|
|
||
| #include "aclnn/acl_meta.h" | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /* function: aclnnLightningAttentionDecodeGetWorkspaceSize | ||
| * parameters : | ||
| * query : required | ||
| * key : required | ||
| * value : required | ||
| * slopeRate : required | ||
| * kvCachesRef : required | ||
| * slotIds : required | ||
| * inputLayoutOptional : optional | ||
| * attentionOut : required | ||
| * kvCachesRef : required | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| * workspaceSize : size of workspace(output). | ||
| * executor : executor context(output). | ||
| */ | ||
| __attribute__((visibility("default"))) | ||
| aclnnStatus aclnnLightningAttentionDecodeGetWorkspaceSize( | ||
| const aclTensor *query, | ||
| const aclTensor *key, | ||
| const aclTensor *value, | ||
| const aclTensor *slopeRate, | ||
| aclTensor *kvCachesRef, | ||
| const aclTensor *slotIds, | ||
| char *inputLayoutOptional, | ||
| const aclTensor *attentionOut, | ||
| uint64_t *workspaceSize, | ||
| aclOpExecutor **executor); | ||
|
|
||
| /* function: aclnnLightningAttentionDecode | ||
| * parameters : | ||
| * workspace : workspace memory addr(input). | ||
| * workspaceSize : size of workspace(input). | ||
| * executor : executor context(input). | ||
| * stream : acl stream. | ||
| */ | ||
| __attribute__((visibility("default"))) | ||
| aclnnStatus aclnnLightningAttentionDecode( | ||
| void *workspace, | ||
| uint64_t workspaceSize, | ||
| aclOpExecutor *executor, | ||
| aclrtStream stream); | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif | ||
|
|
||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| /** | ||
| * Copyright (c) 2025 Huawei Technologies Co., Ltd. | ||
| * This file is a part of the CANN Open Software. | ||
| * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). | ||
| * Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| * See LICENSE in the root of the software repository for the full text of the License. | ||
| */ | ||
|
|
||
| #include "register/op_def_registry.h" | ||
|
|
||
| namespace ops { | ||
|
|
||
| class LightningAttentionDecode : public OpDef { | ||
| public: | ||
| explicit LightningAttentionDecode(const char* name) : OpDef(name) | ||
| { | ||
| this->Input("query") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Input("key") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Input("value") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Input("slope_rate") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Input("kv_caches") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Input("slot_ids") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Output("attention") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Output("kv_caches") | ||
| .ParamType(REQUIRED) | ||
| .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) | ||
| .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) | ||
| .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); | ||
| this->Attr("input_layout").AttrType(OPTIONAL).String("BNSD"); | ||
|
|
||
| this->AICore().AddConfig("ascend910b"); | ||
| this->AICore().AddConfig("ascend910_93"); | ||
| } | ||
| }; | ||
|
|
||
| OP_ADD(LightningAttentionDecode); | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,54 @@ | ||||||||||
| /** | ||||||||||
| * Copyright (c) 2025 Huawei Technologies Co., Ltd. | ||||||||||
| * This file is a part of the CANN Open Software. | ||||||||||
| * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). | ||||||||||
| * Please refer to the License for details. You may not use this file except in compliance with the License. | ||||||||||
| * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||||||||||
| * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||||||||||
| * See LICENSE in the root of the software repository for the full text of the License. | ||||||||||
| */ | ||||||||||
|
|
||||||||||
| #include "register/op_def_registry.h" | ||||||||||
|
|
||||||||||
| namespace ops { | ||||||||||
|
|
||||||||||
| static constexpr size_t INDEX_IN_Q = 0; | ||||||||||
| static constexpr size_t INDEX_IN_K = 1; | ||||||||||
| static constexpr size_t INDEX_IN_V = 2; | ||||||||||
| static constexpr size_t INDEX_IN_SLP_RATE = 3; | ||||||||||
| static constexpr size_t INDEX_IN_KV_HIS = 4; | ||||||||||
| static constexpr size_t INDEX_IN_SLT_IDS = 5; | ||||||||||
| static constexpr size_t DIM_2 = 2; | ||||||||||
| static constexpr size_t DIM_3 = 3; | ||||||||||
| static constexpr size_t INDEX_OUT_ATTN = 0; | ||||||||||
| static constexpr size_t INDEX_OUT_KV_CACHES = 1; | ||||||||||
|
|
||||||||||
| static ge::graphStatus InferShapeLightningAttentionDecode(gert::InferShapeContext* context) | ||||||||||
| { | ||||||||||
| const gert::Shape* q_shape = context->GetInputShape(INDEX_IN_Q); | ||||||||||
| gert::Shape* attn_out_shape = context->GetOutputShape(INDEX_OUT_ATTN); | ||||||||||
| gert::Shape* kv_caches_shape = context->GetOutputShape(INDEX_OUT_KV_CACHES); | ||||||||||
| *attn_out_shape = *q_shape; | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The shape inference for
Suggested change
|
||||||||||
|
|
||||||||||
| kv_caches_shape->SetDimNum(q_shape->GetDimNum()); | ||||||||||
| kv_caches_shape->SetDim(0, q_shape->GetDim(0)); | ||||||||||
| kv_caches_shape->SetDim(1, q_shape->GetDim(1)); | ||||||||||
| kv_caches_shape->SetDim(DIM_2, q_shape->GetDim(DIM_3)); | ||||||||||
| kv_caches_shape->SetDim(DIM_3, q_shape->GetDim(DIM_3)); | ||||||||||
|
|
||||||||||
| return ge::GRAPH_SUCCESS; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| static ge::graphStatus InferDataTypeLightningAttentionDecode(gert::InferDataTypeContext *context) | ||||||||||
| { | ||||||||||
| const auto inputDataType = context->GetInputDataType(INDEX_IN_Q); | ||||||||||
| context->SetOutputDataType(INDEX_OUT_ATTN, inputDataType); | ||||||||||
| context->SetOutputDataType(INDEX_OUT_KV_CACHES, inputDataType); | ||||||||||
| return ge::GRAPH_SUCCESS; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| IMPL_OP_INFERSHAPE(LightningAttentionDecode) | ||||||||||
| .InferShape(InferShapeLightningAttentionDecode) | ||||||||||
| .InferDataType(InferDataTypeLightningAttentionDecode); | ||||||||||
|
|
||||||||||
| } | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There appears to be a typo in this file's name:
lightning_attention_docode_torch_adpt.h. It should likely belightning_attention_decode_torch_adpt.h. Please rename the file for consistency and to avoid confusion.