Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ option(BUILD_TESTS "build test or not" OFF)
option(BUILD_DEEPEP_MODULE "build deepep" ON)
option(BUILD_KERNELS_MODULE "build kernels" ON)
option(BUILD_CATLASS_MODULE "build catlass ops within kernels" OFF)
option(BUILD_CATCOC_MODULE "build catcoc ops within kernels" ON)

set(CMAKE_CXX_STANDARD 17)
#set(CMAKE_VERBOSE_MAKEFILE ON)
Expand All @@ -24,6 +25,21 @@ if (BUILD_CATLASS_MODULE)
message(STATUS "[CATLASS] ${CATLASS_DIR}")
endif ()

if (BUILD_CATCOC_MODULE)
    # Expose the feature switch to the compiled sources as a preprocessor macro.
    add_compile_definitions(BUILD_CATCOC_MODULE)

    # Locations of the third-party template libraries; specify your own
    # catlass / catcoc checkout paths here if they live elsewhere.
    set(CATCOC_DIR "${PROJECT_SOURCE_DIR}/3rdparty/catcoc")
    set(CATLASS_DIR "${PROJECT_SOURCE_DIR}/3rdparty/catlass")
    message(STATUS "[CATLASS] ${CATLASS_DIR}")
    message(STATUS "[CATCOC] ${CATCOC_DIR}")

    # shmem: a SHMEM_HOME_PATH CMake variable takes precedence; otherwise fall
    # back to the environment variable of the same name, and abort the
    # configure step when neither is available.
    if (NOT DEFINED SHMEM_HOME_PATH)
        if (DEFINED ENV{SHMEM_HOME_PATH})
            set(SHMEM_HOME_PATH $ENV{SHMEM_HOME_PATH})
        else ()
            message(FATAL_ERROR "Cannot find SHMEM_HOME_PATH, please run set_env.sh.")
        endif ()
    endif ()
endif ()

if (${CMAKE_BUILD_TYPE} MATCHES "RELEASE")
add_compile_options(-O3)
add_compile_options(-fvisibility=hidden -fvisibility-inlines-hidden)
Expand Down
30 changes: 27 additions & 3 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set -e

BUILD_DEEPEP_MODULE="ON"
BUILD_DEEPEP_OPS="ON"
BUILD_CATCOC_OPS="OFF" # need catlass & catcoc in 3rdparty/
BUILD_KERNELS_MODULE="ON"
BUILD_MEMORY_SAVER_MODULE="ON"

Expand Down Expand Up @@ -122,8 +123,8 @@ function build_kernels()
rm -rf $BUILD_DIR
mkdir -p $BUILD_DIR

cmake $COMPILE_OPTIONS -DCMAKE_INSTALL_PREFIX="$OUTPUT_DIR" -DASCEND_HOME_PATH=$ASCEND_HOME_PATH -DASCEND_INCLUDE_DIR=$ASCEND_INCLUDE_DIR -DSOC_VERSION=$SOC_VERSION -DBUILD_DEEPEP_MODULE=$BUILD_DEEPEP_MODULE -DBUILD_KERNELS_MODULE=$BUILD_KERNELS_MODULE -B "$BUILD_DIR" -S .
cmake --build "$BUILD_DIR" -j8 && cmake --build "$BUILD_DIR" --target install
cmake $COMPILE_OPTIONS -DCMAKE_INSTALL_PREFIX="$OUTPUT_DIR" -DASCEND_HOME_PATH=$ASCEND_HOME_PATH -DASCEND_INCLUDE_DIR=$ASCEND_INCLUDE_DIR -DSOC_VERSION=$SOC_VERSION -DBUILD_DEEPEP_MODULE=$BUILD_DEEPEP_MODULE -DBUILD_CATCOC_MODULE=$BUILD_CATCOC_OPS -DBUILD_KERNELS_MODULE=$BUILD_KERNELS_MODULE -B "$BUILD_DIR" -S .
cmake --build "$BUILD_DIR" -j32 && cmake --build "$BUILD_DIR" --target install
cd -
}

Expand Down Expand Up @@ -158,6 +159,29 @@ function build_deepep_kernels()
cd -
}

function build_catcoc_kernels()
{
    # Builds the catcoc kernel library by delegating to build.sh inside
    # csrc/catcoc/ops (the catcoc kernels are compiled with bisheng).
    #
    # Does nothing and returns 0 unless BUILD_CATCOC_OPS is "ON".
    # Exits the script with a non-zero status when a required directory is
    # missing or when the nested build fails.
    if [[ "$BUILD_CATCOC_OPS" != "ON" ]]; then return 0; fi

    # Only announce the build once we know it is actually going to happen.
    echo "building catcoc ops..."

    KERNEL_DIR="csrc/catcoc/ops"
    CATLASS_DIR="3rdparty/catlass"
    CATCOC_DIR="3rdparty/catcoc"

    # Both template libraries are hard prerequisites: stop here instead of
    # printing an "Error" and letting the build continue without them.
    if [ ! -d "$CATLASS_DIR" ]; then
        echo "Error: CATLASS Directory '$CATLASS_DIR' does not exist." >&2
        exit 1
    fi
    if [ ! -d "$CATCOC_DIR" ]; then
        echo "Error: CATCOC Directory '$CATCOC_DIR' does not exist." >&2
        exit 1
    fi

    cd "$KERNEL_DIR" || exit 1
    # Propagate a failed kernel build even when the caller is not running under `set -e`.
    bash build.sh || exit $?
    cd -
}

function build_memory_saver()
{
if [[ "$BUILD_MEMORY_SAVER_MODULE" != "ON" ]]; then return 0; fi
Expand Down Expand Up @@ -200,7 +224,7 @@ function make_sgl_kernel_npu_package()

function main()
{

build_catcoc_kernels
build_kernels
build_deepep_kernels
if pip3 show wheel;then
Expand Down
27 changes: 27 additions & 0 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ if(BUILD_CATLASS_MODULE)
${PROJECT_OP_SRC_BASE}/catlass/op_host/catlass_matmul_basic.cpp
)
endif()
if(BUILD_CATCOC_MODULE)
    # The catcoc kernels are built separately (with bisheng) before this
    # configure step, so the prebuilt shared library must already exist.
    set(CATCOC_LIBRARY_OUTPUT ${PROJECT_OP_SRC_BASE}/catcoc/ops/install)
    if(NOT EXISTS "${CATCOC_LIBRARY_OUTPUT}/libcatcoc_kernel.so")
        message(
            FATAL_ERROR
            "Assert Failed: CATCOC ops lib(libcatcoc_kernel.so) does not exist: ${CATCOC_LIBRARY_OUTPUT}"
        )
    endif()

    # Host-side op sources that are compiled into the plugin.
    list(APPEND OP_SRCS ${PROJECT_OP_SRC_BASE}/catcoc/op_host/catcoc_allgather_matmul.cpp)
    list(APPEND OP_SRCS ${PROJECT_OP_SRC_BASE}/catcoc/op_host/catcoc_matmul_allreduce.cpp)
endif()

# set the so name
set(OP_PLUGIN_NAME sgl_kernel_npu)
Expand Down Expand Up @@ -53,6 +64,7 @@ if(BUILD_CATLASS_MODULE)
${PROJECT_OP_SRC_BASE}/catlass/op_kernel/catlass_matmul_basic_kernel.cpp
)
endif()

ascendc_library(workspace_kernel STATIC ${WORKSPACE_KERNEL_SRCS})
if(BUILD_CATLASS_MODULE)
ascendc_include_directories(workspace_kernel PRIVATE
Expand Down Expand Up @@ -84,6 +96,14 @@ target_link_directories(${OP_PLUGIN_NAME} PRIVATE
${TORCH_DIR}/lib
${TORCH_NPU_DIR}/lib
)
if(BUILD_CATCOC_MODULE)
    # Link the plugin against the shmem library and the prebuilt catcoc kernel library.
    target_link_libraries(${OP_PLUGIN_NAME} PRIVATE ${SHMEM_HOME_PATH}/shmem/lib/libshmem.so)
    target_link_libraries(${OP_PLUGIN_NAME} PRIVATE ${CATCOC_LIBRARY_OUTPUT}/libcatcoc_kernel.so)

    # Install the prebuilt catcoc kernel library so that it is packaged in the wheel.
    install(
        FILES ${CATCOC_LIBRARY_OUTPUT}/libcatcoc_kernel.so
        DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
    )
endif()

target_include_directories(${OP_PLUGIN_NAME} PRIVATE
${PROJECT_OP_SRC_BASE}/utils
Expand All @@ -95,3 +115,10 @@ target_include_directories(${OP_PLUGIN_NAME} PRIVATE
${ASCEND_INCLUDE_DIR}/experiment/platform
${ASCEND_INCLUDE_DIR}/experiment/runtime
)
#if(BUILD_CATCOC_MODULE)
# target_include_directories(${OP_PLUGIN_NAME} PRIVATE
# ${SHMEM_HOME_PATH}/shmem/include
# ${SHMEM_HOME_PATH}/memfabric_hybrid/include/smem/host
# ${SHMEM_HOME_PATH}/memfabric_hybrid/include/smem/device
# )
#endif()
12 changes: 12 additions & 0 deletions csrc/catcoc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# torch.ops.catcoc


## Function Description | 功能描述

### English:
This op family provides fused matmul + communication (and communication + matmul) kernels implemented with the catcoc template library, which is built on top of catlass.

### 中文:
这是调用catcoc模板库(基于catlass)实现的矩阵乘法和通讯融合运算算子

参考/Refs: [CATLASS](https://gitcode.com/cann/catlass) [CATCOC](https://open.codehub.huawei.com/OpenBaize/Ascend/CATCoC)
47 changes: 47 additions & 0 deletions csrc/catcoc/include/catcoc_host_tiling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef KERNEL_CATCOC_HOST_TILING_H
#define KERNEL_CATCOC_HOST_TILING_H

#include <cstdint>
#include <map>

namespace sglang {
namespace npu_kernel {

// Selector for the weight layout. The enumerator names suggest the ND and NZ
// weight formats — NOTE(review): confirm the exact meaning against the kernels.
typedef enum { WEIGHT_ND = 0, WEIGHT_NZ = 1 } WeightFormatMode;

// Selector for the element data type (BF16 / FP16 / FP32).
typedef enum { BF16 = 0, FP16 = 1, FP32 = 2 } DataFormatMode;

// Host-side tiling parameters for the catcoc kernels.
// NOTE(review): this struct is written raw into a byte buffer by the host code,
// so its field order and sizes presumably have to match whatever reads that
// buffer — confirm before reordering or resizing any member.
struct KernelCATCOCHostTilingData {
// Problem size; these three members have no default and must be set by the caller.
uint32_t m; // get from matmul M
uint32_t n; // get from matmul N
uint32_t k; // get from matmul K

// Tuning defaults. NOTE(review): m0/k0/n0 look like per-dimension tile sizes
// and the remaining fields like swizzle / communication split parameters —
// confirm their exact meaning against the kernel implementation.
uint32_t m0 = 128;
uint32_t k0 = 256;
uint32_t n0 = 256;
uint32_t swizzleDirect = 1;
uint32_t swizzleOffset = 7;
uint32_t ubMoveNum = 16 * 1024;
uint32_t pValue = 3;
uint32_t commNpuSplit = 2;
uint32_t commDataSplit = 1;
uint32_t lenPerLoop = 128 * 256 / 2;

// Format selectors, defaulting to the WEIGHT_ND and BF16 enumerators.
int64_t weight_format_mode = WEIGHT_ND;
int64_t data_format_mode = BF16;
};

} // namespace npu_kernel
} // namespace sglang

#endif // KERNEL_CATCOC_HOST_TILING_H
67 changes: 67 additions & 0 deletions csrc/catcoc/include/catcoc_host_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef KERNEL_CATCOC_HOST_UTILS_H
#define KERNEL_CATCOC_HOST_UTILS_H

#include <cstdint>
#include "catcoc_host_tiling.h"

namespace sglang {
namespace npu_kernel {

// Alignment unit in bytes; the tiling buffer size in this header is rounded up
// to a multiple of this value.
constexpr uint32_t PADDING_BYTE = 32U;

// Maps a torch scalar type to its DataFormatMode. Only Half and BFloat16 have entries.
inline std::map<c10::ScalarType, DataFormatMode> dTypeMap = {{at::ScalarType::Half, DataFormatMode::FP16},
{at::ScalarType::BFloat16, DataFormatMode::BF16}};

// Maps a weight-format name ("ND" or "NZ") to the matching WeightFormatMode
// value, stored as uint16_t.
inline std::unordered_map<c10::string_view, uint16_t> weightFormatMap = {{"ND", WeightFormatMode::WEIGHT_ND},
{"NZ", WeightFormatMode::WEIGHT_NZ}};

// batch size -> memory index
// NOTE(review): not referenced elsewhere in this header; presumably the upper
// bound of a table indexed by batch size — confirm against its users.
constexpr uint32_t MAX_CAPTURE_NUM = 512;

template <typename MapType>
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
const char *mode_name)
{
std::string modeStr(mode_name);
c10::string_view mode_str = mode_opt.value_or(default_mode);
auto it = mode_map.find(mode_str);
// if input mode is unsupported, use default value
TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
return it->second;
}

// Builds the host-side tiling blob for a catcoc kernel launch.
//
//   m, n, k            - matmul problem size, copied into the tiling data.
//   weight_format_mode - stored as-is in the tiling data (see WeightFormatMode).
//   data_format_mode   - stored as-is in the tiling data (see DataFormatMode).
//   blockDim           - output: set to the AIV core count reported by the platform.
//
// Returns a CPU byte tensor whose size is sizeof(KernelCATCOCHostTilingData)
// rounded up to a multiple of PADDING_BYTE. The tensor stays on the host; it is
// not copied to the device here.
inline at::Tensor get_tiling_tensor(uint32_t &m, uint32_t &n, uint32_t &k, int64_t weight_format_mode,
                                    int64_t data_format_mode, uint32_t &blockDim)
{
    auto ascendc_platform = platform_ascendc::PlatformAscendCManager::GetInstance();
    blockDim = static_cast<uint32_t>(ascendc_platform->GetCoreNumAiv());

    // align to 32 bytes
    int32_t tiling_size = (sizeof(KernelCATCOCHostTilingData) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
    auto tiling_buffer = at::empty({tiling_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));

    KernelCATCOCHostTilingData *tiling_data = reinterpret_cast<KernelCATCOCHostTilingData *>(tiling_buffer.data_ptr());
    // at::empty() returns uninitialized storage and reinterpret_cast runs no
    // constructor, so the struct's default member initializers (m0, k0, n0,
    // swizzle*, ubMoveNum, ...) must be applied explicitly; otherwise those
    // fields would hold whatever bytes happened to be in the allocation.
    *tiling_data = KernelCATCOCHostTilingData{};
    tiling_data->m = m;
    tiling_data->n = n;
    tiling_data->k = k;
    tiling_data->weight_format_mode = weight_format_mode;
    tiling_data->data_format_mode = data_format_mode;

    return tiling_buffer;
}

} // namespace npu_kernel
} // namespace sglang

#endif // KERNEL_CATCOC_HOST_UTILS_H
48 changes: 48 additions & 0 deletions csrc/catcoc/include/catcoc_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef KERNEL_CATCOC_KERNEL_H
#define KERNEL_CATCOC_KERNEL_H

#include <acl/acl.h>
#include "catcoc_host_tiling.h"

// Launcher declarations for the catcoc kernels. The names follow the pattern
// catcoc_<op>_<dtype>_<weight format>_kernel, with op in {allgather_matmul,
// matmul_allreduce}, dtype in {bf16, fp16} and weight format in {wnd, wnz}.
// All eight declarations share the same parameter list.
// NOTE(review): "wnd"/"wnz" presumably correspond to WEIGHT_ND / WEIGHT_NZ, and
// the definitions presumably come from the prebuilt libcatcoc_kernel.so —
// confirm both; neither is established by this header.

// allgather + matmul variants
void catcoc_allgather_matmul_bf16_wnd_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);

void catcoc_allgather_matmul_fp16_wnd_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);

void catcoc_allgather_matmul_bf16_wnz_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);

void catcoc_allgather_matmul_fp16_wnz_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);

// matmul + allreduce variants
void catcoc_matmul_allreduce_bf16_wnd_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);

void catcoc_matmul_allreduce_fp16_wnd_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);

void catcoc_matmul_allreduce_bf16_wnz_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);

void catcoc_matmul_allreduce_fp16_wnz_kernel(uint32_t blockNum, aclrtStream stream, uint64_t fftsAddr, uint64_t teamIdx,
uint8_t *gmA, uint8_t *gmB, uint8_t *gmC, uint8_t *gmSymmetric,
uint8_t *gmWorkspace, uint8_t *gmTiling);
#endif // KERNEL_CATCOC_KERNEL_H
Loading