Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion csrc/build_aclnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}


CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;"
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;lightning_attention_decode;lightning_attention_prefill"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
Expand Down Expand Up @@ -68,6 +68,8 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"causal_conv1d"
"moe_grouped_matmul"
"lightning_indexer_quant"
"lightning_attention_decode"
"lightning_attention_prefill"
)
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
SOC_ARG="ascend910_93"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There appears to be a typo in this file's name: lightning_attention_docode_torch_adpt.h. It should likely be lightning_attention_decode_torch_adpt.h. Please rename the file for consistency and to avoid confusion.

* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LIGHTNING_ATTENTION_DECODE_TORCH_ADPT_H
#define LIGHTNING_ATTENTION_DECODE_TORCH_ADPT_H

namespace vllm_ascend {
// Torch adapter for the LightningAttentionDecode NPU kernel.
//
// Inputs (layout "BNSD" per the op attr):
//   query/key/value : attention tensors; assumes query is 4-D
//                     (batch, num_heads, seq, head_dim) — TODO confirm
//   kv_caches_ref   : cache tensor updated in place by the kernel
//   slope_rate      : per-head decay slopes
//   slot_ids        : int32 cache slot indices
// Returns: attention output of shape (batch, num_heads * head_dim),
//          same dtype/device as query.
at::Tensor npu_lightning_attention_decode(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &kv_caches_ref,
    const at::Tensor &slope_rate,
    const at::Tensor &slot_ids)
{
    // Pass the braced size list straight to at::empty instead of binding it
    // to `auto` (which would deduce std::initializer_list — a known pitfall).
    // query.options() already carries query's dtype, so no explicit
    // .dtype(...) override is needed.
    at::Tensor attention_out = at::empty(
        {query.size(0), query.size(1) * query.size(3)},
        query.options());
    EXEC_NPU_CMD(
        aclnnLightningAttentionDecode,
        query,
        key,
        value,
        slope_rate,
        kv_caches_ref,
        slot_ids,
        "BNSD",
        attention_out);
    return attention_out;
}
}

#endif
45 changes: 45 additions & 0 deletions csrc/lightning_attention_decode/op_host/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

# Per-op compile options for the LightningAttentionDecode kernel sources.
add_ops_compile_options(
OP_NAME LightningAttentionDecode
OPTIONS --cce-auto-sync
-Wno-deprecated-declarations
-Werror
)

# Operator definition (dtype/format/attr registration).
target_sources(op_host_aclnnInner PRIVATE
lightning_attention_decode_def.cpp
)

# Public aclnn two-phase API wrappers.
# NOTE(review): source is aclnn_lightning_attention.cpp but the installed
# header below is aclnn_lightning_attention_decode.h — confirm the naming
# mismatch is intentional.
target_sources(opapi PRIVATE
aclnn_lightning_attention.cpp
)

# Tiling implementation.
target_sources(optiling PRIVATE
lightning_attention_decode_tiling.cpp
)

# Closed-source builds also compile tiling into opmaster_ct.
if (NOT BUILD_OPEN_PROJECT)
target_sources(opmaster_ct PRIVATE
lightning_attention_decode_tiling.cpp
)
endif ()

target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)

# Shape/dtype inference (IR proto).
target_sources(opsproto PRIVATE
lightning_attention_decode_proto.cpp
)

# OPTIONAL: install succeeds even if the header is absent.
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_lightning_attention_decode.h
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

#include "graph/types.h"
#include "aclnn_lightning_attention_decode.h"

#ifdef __cplusplus
extern "C" {
#endif

// Generated inner entry points, implemented by the ACLNN op compiler and
// resolved at link time. NOTE(review): these signatures must stay in sync
// with the generated code for LightningAttentionDecode — verify on regen.
extern aclnnStatus aclnnInnerLightningAttentionDecodeGetWorkspaceSize(
const aclTensor *query,
const aclTensor *key,
const aclTensor *value,
const aclTensor *slopeRate,
aclTensor *kvCachesRef,
const aclTensor *slotIds,
char *inputLayoutOptional,
const aclTensor *attentionOut,
uint64_t *workspaceSize,
aclOpExecutor **executor);

extern aclnnStatus aclnnInnerLightningAttentionDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);

// Phase 1 of the aclnn two-phase call: compute the required workspace size
// and build an executor. Forwards directly to the generated inner entry.
aclnnStatus aclnnLightningAttentionDecodeGetWorkspaceSize(
    const aclTensor *query,
    const aclTensor *key,
    const aclTensor *value,
    const aclTensor *slopeRate,
    aclTensor *kvCachesRef,
    const aclTensor *slotIds,
    char *inputLayoutOptional,
    const aclTensor *attentionOut,
    uint64_t *workspaceSize,
    aclOpExecutor **executor)
{
    return aclnnInnerLightningAttentionDecodeGetWorkspaceSize(
        query, key, value, slopeRate, kvCachesRef, slotIds,
        inputLayoutOptional, attentionOut, workspaceSize, executor);
}

// Phase 2 of the aclnn two-phase call: launch the kernel on `stream` using
// the workspace and executor produced by phase 1.
aclnnStatus aclnnLightningAttentionDecode(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream)
{
    return aclnnInnerLightningAttentionDecode(workspace, workspaceSize, executor, stream);
}

#ifdef __cplusplus
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

#ifndef ACLNN_LIGHTNING_ATTENTION_DECODE_H_
#define ACLNN_LIGHTNING_ATTENTION_DECODE_H_

#include "aclnn/acl_meta.h"

#ifdef __cplusplus
extern "C" {
#endif

/* function: aclnnLightningAttentionDecodeGetWorkspaceSize
* parameters :
* query : required
* key : required
* value : required
* slopeRate : required
* kvCachesRef : required
* slotIds : required
* inputLayoutOptional : optional
* attentionOut : required
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The parameter kvCachesRef is documented twice (here and on line 26). Please remove this duplicate line to improve documentation clarity.

* workspaceSize : size of workspace(output).
* executor : executor context(output).
*/
__attribute__((visibility("default")))
aclnnStatus aclnnLightningAttentionDecodeGetWorkspaceSize(
const aclTensor *query,
const aclTensor *key,
const aclTensor *value,
const aclTensor *slopeRate,
aclTensor *kvCachesRef,
const aclTensor *slotIds,
char *inputLayoutOptional,
const aclTensor *attentionOut,
uint64_t *workspaceSize,
aclOpExecutor **executor);

/* function: aclnnLightningAttentionDecode
* parameters :
* workspace : workspace memory addr(input).
* workspaceSize : size of workspace(input).
* executor : executor context(input).
* stream : acl stream.
*/
__attribute__((visibility("default")))
aclnnStatus aclnnLightningAttentionDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

#include "register/op_def_registry.h"

namespace ops {

// Graph-engine operator definition for LightningAttentionDecode.
// The order of the Input()/Output() registrations below fixes the tensor
// indices consumed by tiling and infershape (query=0, key=1, value=2,
// slope_rate=3, kv_caches=4, slot_ids=5; attention=0, kv_caches=1) —
// do not reorder. Each DataType/Format list holds one entry per supported
// dtype configuration (fp16 / bf16 / fp32).
class LightningAttentionDecode : public OpDef {
public:
explicit LightningAttentionDecode(const char* name) : OpDef(name)
{
this->Input("query")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("key")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("slope_rate")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
// kv_caches appears as both input and output: the cache is updated
// in place by the kernel.
this->Input("kv_caches")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
// slot_ids is int32 in every dtype configuration.
this->Input("slot_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("attention")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("kv_caches")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
// Layout attr; the torch adapter always passes "BNSD".
this->Attr("input_layout").AttrType(OPTIONAL).String("BNSD");

// Supported SoC targets: A2 (ascend910b) and A3 (ascend910_93).
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
}
};

// Register the operator definition with the op registry.
OP_ADD(LightningAttentionDecode);
} // namespace ops
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

#include "register/op_def_registry.h"

namespace ops {

// Tensor indices matching the registration order in
// lightning_attention_decode_def.cpp.
static constexpr size_t INDEX_IN_Q = 0;        // query
static constexpr size_t INDEX_IN_K = 1;        // key
static constexpr size_t INDEX_IN_V = 2;        // value
static constexpr size_t INDEX_IN_SLP_RATE = 3; // slope_rate
static constexpr size_t INDEX_IN_KV_HIS = 4;   // kv_caches (input side)
static constexpr size_t INDEX_IN_SLT_IDS = 5;  // slot_ids
// Named dim positions used when writing shapes.
static constexpr size_t DIM_2 = 2;
static constexpr size_t DIM_3 = 3;
static constexpr size_t INDEX_OUT_ATTN = 0;      // attention output
static constexpr size_t INDEX_OUT_KV_CACHES = 1; // kv_caches output

// Shape inference for LightningAttentionDecode.
// Assumes query is 4-D (batch, num_heads, seq, head_dim) — TODO confirm.
// Outputs:
//   attention : 2-D (batch, num_heads * head_dim) — must match the torch
//               adapter, which allocates {q.size(0), q.size(1) * q.size(3)}.
//   kv_caches : (batch, num_heads, head_dim, head_dim); dims 2 and 3 are
//               both head_dim (presumably a d x d decay state — verify).
static ge::graphStatus InferShapeLightningAttentionDecode(gert::InferShapeContext* context)
{
    const gert::Shape* q_shape = context->GetInputShape(INDEX_IN_Q);
    gert::Shape* attn_out_shape = context->GetOutputShape(INDEX_OUT_ATTN);
    gert::Shape* kv_caches_shape = context->GetOutputShape(INDEX_OUT_KV_CACHES);
    // Guard against missing shapes rather than dereferencing null.
    if (q_shape == nullptr || attn_out_shape == nullptr || kv_caches_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }

    // Fix: the output is 2-D, not a copy of the 4-D query shape.
    attn_out_shape->SetDimNum(2);
    attn_out_shape->SetDim(0, q_shape->GetDim(0));
    attn_out_shape->SetDim(1, q_shape->GetDim(1) * q_shape->GetDim(3));

    kv_caches_shape->SetDimNum(q_shape->GetDimNum());
    kv_caches_shape->SetDim(0, q_shape->GetDim(0));
    kv_caches_shape->SetDim(1, q_shape->GetDim(1));
    kv_caches_shape->SetDim(DIM_2, q_shape->GetDim(DIM_3));
    kv_caches_shape->SetDim(DIM_3, q_shape->GetDim(DIM_3));

    return ge::GRAPH_SUCCESS;
}

// Dtype inference: both outputs (attention, kv_caches) take query's dtype.
static ge::graphStatus InferDataTypeLightningAttentionDecode(gert::InferDataTypeContext *context)
{
    const auto qDataType = context->GetInputDataType(INDEX_IN_Q);
    context->SetOutputDataType(INDEX_OUT_KV_CACHES, qDataType);
    context->SetOutputDataType(INDEX_OUT_ATTN, qDataType);
    return ge::GRAPH_SUCCESS;
}

// Bind the shape/dtype inference functions to the registered op.
IMPL_OP_INFERSHAPE(LightningAttentionDecode)
.InferShape(InferShapeLightningAttentionDecode)
.InferDataType(InferDataTypeLightningAttentionDecode);

} // namespace ops
Loading
Loading