-
Notifications
You must be signed in to change notification settings - Fork 974
Expand file tree
/
Copy pathlightning_attention_decode_def.cpp
More file actions
67 lines (62 loc) · 3.16 KB
/
lightning_attention_decode_def.cpp
File metadata and controls
67 lines (62 loc) · 3.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "register/op_def_registry.h"
namespace ops {
class LightningAttentionDecode : public OpDef {
public:
explicit LightningAttentionDecode(const char* name) : OpDef(name)
{
this->Input("query")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("key")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("slope_rate")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("kv_caches")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("slot_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("attention")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("kv_caches")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("input_layout").AttrType(OPTIONAL).String("BNSD");
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
}
};
// Hand the LightningAttentionDecode definition class to the OP_ADD macro.
// NOTE(review): OP_ADD is presumably supplied by "register/op_def_registry.h" and
// registers the op prototype with the framework - confirm against that header.
OP_ADD(LightningAttentionDecode);
}