
Commit f6ce854

Author: litangss

add readme & remove unused comment
1 parent d3775f1 commit f6ce854

8 files changed (+51, -140 lines)

csrc/catcoc/README.md

Lines changed: 16 additions & 5 deletions
@@ -1,12 +1,23 @@
 # torch.ops.catcoc
 
 
-## Function Description | 功能描述
+## Function Description
 
-### English:
 This is the catcoc (based on catlass) implementation of the matmul+comm / comm+matmul fused kernels
 
-### Chinese:
-This is a matmul-and-communication fused operator implemented with the catcoc template library (based on catlass)
+Refs: [CATLASS](https://gitcode.com/cann/catlass) [CATCOC](https://open.codehub.huawei.com/OpenBaize/Ascend/CATCoC)
 
-Refs: [CATLASS](https://gitcode.com/cann/catlass) [CATCOC](https://open.codehub.huawei.com/OpenBaize/Ascend/CATCoC)
+## Usage
+### Building
+1. Clone catlass into 3rdparty/catlass
+2. Clone [shmem (coldev branch)](https://gitee.com/ascend/shmem/tree/coldev/) into 3rdparty and rename examples/templates to 3rdparty/catcoc
+3. Set BUILD_CATCOC_OPS=ON in build.sh
+4. Run `bash build.sh -a kernels`
+
+### Usage examples
+See tests/python/sgl_kernel_npu/test_catcoc_xxx.py for examples
+
+### Restrictions
+1. FP32 is not supported
+2. shmem teams are not supported
+3. Dequant fusion (W8A8) is not supported
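
For orientation, below is a minimal, hypothetical sketch of what calling one of these fused ops from Python might look like once the kernels are built. The `torch.ops.catcoc` namespace comes from the README title, but the exact registered op name, argument order, and trailing scalar arguments are assumptions inferred from the HOST_API signatures in the diffs below; the tests under tests/python/sgl_kernel_npu/ are the authoritative call sites.

```python
# Hypothetical usage sketch -- NOT a confirmed API. The op name and argument
# order are guesses from the HOST_API signatures in this commit; shmem must
# already be initialized on every rank (see shmem_init in the test diffs).
import torch
import torch_npu  # Ascend PyTorch adapter

m, k, n = 128, 256, 512
a = torch.randn(m, k, dtype=torch.float16).npu()  # FP32 is unsupported (README)
b = torch.randn(k, n, dtype=torch.float16).npu()
c = torch.empty(m, n, dtype=torch.float16).npu()

# Fused matmul + all-reduce across the ranks sharing the symmetric heap;
# the trailing scalars stand in for team_id and format_mode.
torch.ops.catcoc.catcoc_matmul_allreduce(a, b, c, 0, "ND")
```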

csrc/catcoc/include/catcoc_host_utils.h

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ inline at::Tensor get_tiling_tensor(uint32_t &m, uint32_t &n, uint32_t &k, int64
     tiling_data->weight_format_mode = weight_format_mode;
     tiling_data->data_format_mode = data_format_mode;
 
-    // auto tiling_tensor = TorchNpuHelper::CopyTensorHostToDevice(tiling_buffer);
     return tiling_buffer;
 }
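
The removed line above was the last remnant of an explicit host-to-device copy; `get_tiling_tensor` now returns the host-side buffer directly. As a rough Python analogy of this pack-params-into-a-byte-tensor pattern: the deleted hex-dump debug code further down reads `m`, `n`, and `k` at byte offsets 0, 4, and 8, which suggests (but does not guarantee) the leading layout of `KernelCATCOCHostTilingData` assumed here.

```python
# Illustrative analogy only: mimics how get_tiling_tensor fills a tiling
# struct inside a raw uint8 tensor. The '<3I' layout (little-endian uint32
# m, n, k at offsets 0/4/8) is inferred from the deleted debug hex dumps;
# the real struct has many more fields.
import struct

import torch

def make_tiling_buffer(m: int, n: int, k: int) -> torch.Tensor:
    raw = struct.pack("<3I", m, n, k)                  # leading fields only
    buf = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
    return buf.clone()                                 # own the memory

tiling = make_tiling_buffer(128, 512, 256)
assert tiling.numel() == 12  # 3 fields x 4 bytes each
```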

csrc/catcoc/op_host/catcoc_allgather_matmul.cpp

Lines changed: 0 additions & 25 deletions
@@ -66,7 +66,6 @@ HOST_API void catcoc_allgather_matmul(const at::Tensor &input_a, const at::Tenso
     "tensor type only support half and bf16");
 
     auto formatMode = static_cast<WeightFormatMode>(GetModeVal(weightFormatMap, format_mode, "ND", "format_mode"));
-    // TORCH_CHECK(formatMode == WeightFormatMode::WEIGHT_ND, "current ops only support weightFormat ND");
 
     uint32_t m = input_a.size(0);
     uint32_t k = input_a.size(1);
@@ -77,9 +76,6 @@ HOST_API void catcoc_allgather_matmul(const at::Tensor &input_a, const at::Tenso
     auto cpu_tiling_tensor = get_tiling_tensor(m, n, k, formatMode, dTypeMap[aType], blockDim);
 
     auto tiling_data_cpu = reinterpret_cast<KernelCATCOCHostTilingData *>(cpu_tiling_tensor.data_ptr<uint8_t>());
-    // printf("m is: %d ;", tiling_data_cpu->m);
-    // printf("n is: %d ;", tiling_data_cpu->n);
-    // printf("k is: %d ;\n", tiling_data_cpu->k);
 
     int32_t batchIdx = m - 1;
     uint32_t tilingSize = sizeof(KernelCATCOCHostTilingData);
@@ -121,19 +117,6 @@ HOST_API void catcoc_allgather_matmul(const at::Tensor &input_a, const at::Tenso
     rtGetC2cCtrlAddr(&fftsAddr, &len);
     auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_tensor.data_ptr());
 
-    /*
-    printf("[host] tiling_ptr on host is %ld\n", tiling_ptr);
-    printf("[host] ipt_a_ptr is %ld, ipt_b_ptr is %ld, opt_c_ptr is %ld\n", a_ptr, b_ptr, c_ptr);
-    printf("[host] fftsAddr is %lu, symm_ptr is %lu\n", fftsAddr, symm_ptr);
-
-    at::Tensor cpu_tensor = tiling_tensor.to(at::kCPU).contiguous();
-    uint8_t * data_ptr = cpu_tensor.data_ptr<uint8_t>();
-    printf("tiling_ptr on host is %ld\n", tiling_ptr);
-    printf("M element (hex): %02x %02x %02x %02x\n", data_ptr[0], data_ptr[1], data_ptr[2], data_ptr[3]);
-    printf("N element (hex): %02x %02x %02x %02x\n", data_ptr[4], data_ptr[5], data_ptr[6], data_ptr[7]);
-    printf("K element (hex): %02x %02x %02x %02x\n", data_ptr[8], data_ptr[9], data_ptr[10], data_ptr[11]);
-    */
-
     auto stream = c10_npu::getCurrentNPUStream().stream(false);
     auto teamIdx = (uint64_t)teamId;
     uint32_t aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
@@ -175,14 +158,6 @@ HOST_API void catcoc_allgather_matmul(const at::Tensor &input_a, const at::Tenso
         AT_ERROR("Unknown tiling cases, ops exec failed!");
     }
     at_npu::native::OpCommand::RunOpApiV2("catcoc_allgather_matmul_kernel", acl_call);
-
-    /*
-    auto teamIdx = (uint64_t)teamId;
-    uint32_t block_dim = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
-    // gmWorkspace is a dummy input for ascendc compile with tiling, catcoc ops use gmSymmetric as actual workspace
-    EXEC_KERNEL_CMD(catcoc_allgather_matmul_kernel, block_dim, fftsAddr, teamIdx, input_a, input_b, output_c,
-                    symm_ptr, workspace_tensor, tiling_ptr);
-    */
 }
 
 }  // namespace npu_kernel
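
After this cleanup, the host wrapper's control flow is easy to state: it selects an `acl_call` lambda from the (dtype, weight format) pair and hands it to `OpCommand::RunOpApiV2`, with unsupported combinations falling through to `AT_ERROR`. A small Python analogy of that dispatch-table design (launcher names invented for illustration):

```python
# Python analogy of the host-side dispatch: one pre-built launcher per
# supported (dtype, weight-format) pair, loud failure otherwise, mirroring
# AT_ERROR("Unknown tiling cases, ops exec failed!"). All names are invented.
from typing import Callable, Dict, Tuple

def launch_half_nd() -> int:
    return 0  # stands in for ACLRT_LAUNCH_KERNEL(...) on the NPU stream

def launch_bf16_nd() -> int:
    return 0

DISPATCH: Dict[Tuple[str, str], Callable[[], int]] = {
    ("half", "ND"): launch_half_nd,
    ("bf16", "ND"): launch_bf16_nd,
}

def run_op(dtype: str, weight_format: str) -> int:
    acl_call = DISPATCH.get((dtype, weight_format))
    if acl_call is None:
        raise RuntimeError("Unknown tiling cases, ops exec failed!")
    return acl_call()  # the C++ version defers this call to RunOpApiV2
```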

csrc/catcoc/op_host/catcoc_matmul_allreduce.cpp

Lines changed: 1 addition & 33 deletions
@@ -65,7 +65,6 @@ HOST_API void catcoc_matmul_allreduce(const at::Tensor &input_a, const at::Tenso
     "tensor type only support half and bf16");
 
     auto formatMode = static_cast<WeightFormatMode>(GetModeVal(weightFormatMap, format_mode, "ND", "format_mode"));
-    // TORCH_CHECK(formatMode == WeightFormatMode::WEIGHT_ND, "current ops only support weightFormat ND");
 
     uint32_t m = input_a.size(0);
     uint32_t k = input_a.size(1);
@@ -76,9 +75,6 @@ HOST_API void catcoc_matmul_allreduce(const at::Tensor &input_a, const at::Tenso
     auto cpu_tiling_tensor = get_tiling_tensor(m, n, k, formatMode, dTypeMap[aType], blockDim);
 
     auto tiling_data_cpu = reinterpret_cast<KernelCATCOCHostTilingData *>(cpu_tiling_tensor.data_ptr<uint8_t>());
-    // printf("m is: %d ;", tiling_data_cpu->m);
-    // printf("n is: %d ;", tiling_data_cpu->n);
-    // printf("k is: %d ;\n", tiling_data_cpu->k);
 
     int32_t batchIdx = m - 1;
     uint32_t tilingSize = sizeof(KernelCATCOCHostTilingData);
@@ -120,30 +116,10 @@ HOST_API void catcoc_matmul_allreduce(const at::Tensor &input_a, const at::Tenso
     rtGetC2cCtrlAddr(&fftsAddr, &len);
     auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_tensor.data_ptr());
 
-    /*
-    printf("[host] tiling_ptr on host is %ld\n", tiling_ptr);
-    printf("[host] ipt_a_ptr is %ld, ipt_b_ptr is %ld, opt_c_ptr is %ld\n", a_ptr, b_ptr, c_ptr);
-    printf("[host] fftsAddr is %lu, symm_ptr is %lu\n", fftsAddr, symm_ptr);
-
-    at::Tensor cpu_tensor = tiling_tensor.to(at::kCPU).contiguous();
-    uint8_t * data_ptr = cpu_tensor.data_ptr<uint8_t>();
-    printf("tiling_ptr on host is %ld\n", tiling_ptr);
-    printf("M element (hex): %02x %02x %02x %02x\n", data_ptr[0], data_ptr[1], data_ptr[2], data_ptr[3]);
-    printf("N element (hex): %02x %02x %02x %02x\n", data_ptr[4], data_ptr[5], data_ptr[6], data_ptr[7]);
-    printf("K element (hex): %02x %02x %02x %02x\n", data_ptr[8], data_ptr[9], data_ptr[10], data_ptr[11]);
-    */
-
     auto stream = c10_npu::getCurrentNPUStream().stream(false);
     auto teamIdx = (uint64_t)teamId;
     uint32_t aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
-    /*
-    auto acl_call = [aicCoreNum, stream, a_ptr, b_ptr, c_ptr, symm_ptr, workspace_ptr, tiling_ptr]() -> int {
-        printf("tiling_ptr on launch is %ld\n", tiling_ptr);
-        ACLRT_LAUNCH_KERNEL(catcoc_allgather_matmul_kernel)
-        (aicCoreNum, stream, a_ptr, b_ptr, c_ptr, symm_ptr, workspace_ptr, tiling_ptr);
-        return 0;
-    };
-    */
+
     std::function<int()> acl_call;
     if ((aType == at::ScalarType::Half) && (formatMode == WeightFormatMode::WEIGHT_ND)) {
         acl_call = [aicCoreNum, stream, fftsAddr, teamIdx, a_ptr, b_ptr, c_ptr, symm_ptr, workspace_ptr,
@@ -181,14 +157,6 @@ HOST_API void catcoc_matmul_allreduce(const at::Tensor &input_a, const at::Tenso
         AT_ERROR("Unknown tiling cases, ops exec failed!");
     }
     at_npu::native::OpCommand::RunOpApiV2("catcoc_matmul_allreduce_kernel", acl_call);
-
-    /*
-    auto teamIdx = (uint64_t)teamId;
-    uint32_t block_dim = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
-    // gmWorkspace is a dummy input for ascendc compile with tiling, catcoc ops use gmSymmetric as actual workspace
-    EXEC_KERNEL_CMD(catcoc_allgather_matmul_kernel, block_dim, fftsAddr, teamIdx, input_a, input_b, output_c,
-                    symm_ptr, workspace_tensor, tiling_ptr);
-    */
 }
 
 }  // namespace npu_kernel

csrc/catcoc/ops/op_kernel/catcoc_allgather_matmul_kernel.hpp

Lines changed: 0 additions & 23 deletions
@@ -74,29 +74,6 @@ class CatCocAllgatherMatmul
         uint32_t k0 = tiling_data->k0;
         uint32_t n0 = tiling_data->n0;
 
-        /*
-        if(rankIdx == 0) {
-            AscendC::printf("m is: %u ;", tiling_data->m);
-            AscendC::printf("n is: %u ;", tiling_data->n);
-            AscendC::printf("k is: %u ;\n", k);
-            AscendC::printf("rankIdx is %u ; rankSize is %u ; teamIdx is: %d ;\n", rankIdx, rankSize, newTeamIdx);
-
-            AscendC::printf("[dev] tiling_ptr on device is %lu \n", (uint64_t) tiling_data);
-            AscendC::printf("[dev] ipt_a_ptr is %ld, ipt_b_ptr is %ld, opt_c_ptr is %ld\n", gmA, gmB, gmC);
-            printf("[dev] fftsAddr is %lu, symm_ptr is %lu\n", fftsAddr, (uint64_t) gmSymmetric);
-        }
-        */
-
-        /*
-        uint32_t swizzleOffset = tiling_data->swizzleOffset;
-        uint32_t swizzleDirect = tiling_data->swizzleDirect;
-        uint32_t pValue = tiling_data->pValue;
-        uint32_t commDataSplit = tiling_data->commDataSplit;
-        uint32_t commNpuSplit = tiling_data->commNpuSplit;
-        uint32_t ubMoveNum = tiling_data->ubMoveNum;
-        uint32_t lenPerLoop = tiling_data->lenPerLoop;
-        */
-
         // switch cases
         using ElementA =
             typename std::conditional_t<dMode == DataFormatMode::FP16, half,

csrc/catcoc/ops/op_kernel/catcoc_matmul_allreduce_kernel.hpp

Lines changed: 0 additions & 23 deletions
@@ -74,29 +74,6 @@ class CatCocMatmulAllreduce
         uint32_t k0 = tiling_data->k0;
         uint32_t n0 = tiling_data->n0;
 
-        /*
-        if(rankIdx == 0) {
-            AscendC::printf("m is: %u ;", tiling_data->m);
-            AscendC::printf("n is: %u ;", tiling_data->n);
-            AscendC::printf("k is: %u ;\n", k);
-            AscendC::printf("rankIdx is %u ; rankSize is %u ; teamIdx is: %d ;\n", rankIdx, rankSize, newTeamIdx);
-
-            AscendC::printf("[dev] tiling_ptr on device is %lu \n", (uint64_t) tiling_data);
-            AscendC::printf("[dev] ipt_a_ptr is %ld, ipt_b_ptr is %ld, opt_c_ptr is %ld\n", gmA, gmB, gmC);
-            printf("[dev] fftsAddr is %lu, symm_ptr is %lu\n", fftsAddr, (uint64_t) gmSymmetric);
-        }
-        */
-
-        /*
-        uint32_t swizzleOffset = tiling_data->swizzleOffset;
-        uint32_t swizzleDirect = tiling_data->swizzleDirect;
-        uint32_t pValue = tiling_data->pValue;
-        uint32_t commDataSplit = tiling_data->commDataSplit;
-        uint32_t commNpuSplit = tiling_data->commNpuSplit;
-        uint32_t ubMoveNum = tiling_data->ubMoveNum;
-        uint32_t lenPerLoop = tiling_data->lenPerLoop;
-        */
-
        // switch cases
         using ElementA =
             typename std::conditional_t<dMode == DataFormatMode::FP16, half,

tests/python/sgl_kernel_npu/test_catcoc_allgather_matmul.py

Lines changed: 17 additions & 15 deletions
@@ -70,22 +70,24 @@ def direct_testing(input_a, input_b, input_c, team_id=0, group_list=(), use_nz=F
 
 def shmem_init(rank, world_size):
     # original init
-    # from shmem import set_conf_store_tls
-    #
-    # global g_shmem_addr, g_ash_size, g_malloc_size
-    # set_conf_store_tls(False, "")
-    # shmem_addr = "tcp://127.0.0.1:26666"
-    # attributes = ash.InitAttr()
-    # attributes.my_rank = rank
-    # attributes.n_ranks = world_size
-    # attributes.local_mem_size = g_ash_size
-    # attributes.ip_port = shmem_addr
-    # attributes.option_attr.data_op_engine_type = ash.OpEngineType.MTE
-    # ret = ash.shmem_init(attributes)
-    # assert ret == 0, '[ERROR] aclshmem_init failed'
-    #
-    # g_shmem_addr = ash.shmem_malloc(g_malloc_size)
+    from shmem import set_conf_store_tls
 
+    global g_shmem_addr, g_ash_size, g_malloc_size
+    set_conf_store_tls(False, "")
+    shmem_addr = "tcp://127.0.0.1:26666"
+    attributes = ash.InitAttr()
+    attributes.my_rank = rank
+    attributes.n_ranks = world_size
+    attributes.local_mem_size = g_ash_size
+    attributes.ip_port = shmem_addr
+    attributes.option_attr.data_op_engine_type = ash.OpEngineType.MTE
+    ret = ash.shmem_init(attributes)
+    assert ret == 0, "[ERROR] aclshmem_init failed"
+
+    g_shmem_addr = ash.shmem_malloc(g_malloc_size)
+
+
+def shmem_init_uid(rank, world_size):
     # uid init(need env SHMEM_UID_SOCK_IFNAM=enp194s0f0::inet4)
     global g_shmem_addr, g_ash_size, g_malloc_size
     # 0. disable TLS
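
The diff above un-comments the TCP-store initialization path and moves the UID-based path into its own `shmem_init_uid`. A hedged sketch of how a multi-rank driver might exercise these helpers (the wiring below is an assumption; the real harness lives in these test files):

```python
# Hedged sketch of a per-rank test driver; assumes the shmem_init /
# shmem_init_uid helpers from the diff above are importable and that one
# process is spawned per NPU. mp.spawn is standard torch.multiprocessing.
import torch.multiprocessing as mp

WORLD_SIZE = 8  # one process per NPU device

def worker(rank: int, world_size: int) -> None:
    # TCP-store init from this diff; shmem_init_uid is the alternative
    # (it needs the SHMEM_UID_SOCK_IFNAM env var, per the test comment).
    shmem_init(rank, world_size)
    # ... build inputs, call direct_testing(...), compare with torch.matmul ...

if __name__ == "__main__":
    mp.spawn(worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE)
```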

tests/python/sgl_kernel_npu/test_catcoc_matmul_allreduce.py

Lines changed: 17 additions & 15 deletions
@@ -62,22 +62,24 @@ def direct_testing(
 
 def shmem_init(rank, world_size):
     # original init
-    # from shmem import set_conf_store_tls
-    #
-    # global g_shmem_addr, g_ash_size, g_malloc_size
-    # set_conf_store_tls(False, "")
-    # shmem_addr = "tcp://127.0.0.1:26666"
-    # attributes = ash.InitAttr()
-    # attributes.my_rank = rank
-    # attributes.n_ranks = world_size
-    # attributes.local_mem_size = g_ash_size
-    # attributes.ip_port = shmem_addr
-    # attributes.option_attr.data_op_engine_type = ash.OpEngineType.MTE
-    # ret = ash.shmem_init(attributes)
-    # assert ret == 0, '[ERROR] aclshmem_init failed'
-    #
-    # g_shmem_addr = ash.shmem_malloc(g_malloc_size)
+    from shmem import set_conf_store_tls
 
+    global g_shmem_addr, g_ash_size, g_malloc_size
+    set_conf_store_tls(False, "")
+    shmem_addr = "tcp://127.0.0.1:26666"
+    attributes = ash.InitAttr()
+    attributes.my_rank = rank
+    attributes.n_ranks = world_size
+    attributes.local_mem_size = g_ash_size
+    attributes.ip_port = shmem_addr
+    attributes.option_attr.data_op_engine_type = ash.OpEngineType.MTE
+    ret = ash.shmem_init(attributes)
+    assert ret == 0, "[ERROR] aclshmem_init failed"
+
+    g_shmem_addr = ash.shmem_malloc(g_malloc_size)
+
+
+def shmem_init_uid(rank, world_size):
     # uid init(need env SHMEM_UID_SOCK_IFNAM=enp194s0f0::inet4)
     global g_shmem_addr, g_ash_size, g_malloc_size
     # 0. disable TLS
