26 changes: 25 additions & 1 deletion build.sh
@@ -5,10 +5,12 @@ BUILD_DEEPEP_MODULE="ON"
BUILD_DEEPEP_OPS="ON"
BUILD_KERNELS_MODULE="ON"
BUILD_MEMORY_SAVER_MODULE="ON"
BUILD_SHMEM_ALLOCATOR_MODULE="OFF"

ONLY_BUILD_DEEPEP_ADAPTER_MODULE="OFF"
ONLY_BUILD_DEEPEP_KERNELs_MODULE="OFF"
ONLY_BUILD_MEMORY_SAVER_MODULE="OFF"
ONLY_BUILD_SHMEM_ALLOCATOR_MODULE="OFF"

DEBUG_MODE="OFF"

@@ -18,6 +20,7 @@ while getopts ":a:hd" opt; do
        BUILD_DEEPEP_MODULE="OFF"
        BUILD_KERNELS_MODULE="OFF"
        BUILD_MEMORY_SAVER_MODULE="OFF"
        BUILD_SHMEM_ALLOCATOR_MODULE="OFF"
        case "$OPTARG" in
            deepep )
                BUILD_DEEPEP_MODULE="ON"
@@ -42,9 +45,13 @@ while getopts ":a:hd" opt; do
                BUILD_MEMORY_SAVER_MODULE="ON"
                ONLY_BUILD_MEMORY_SAVER_MODULE="ON"
                ;;
            shmem-allocator )
                BUILD_SHMEM_ALLOCATOR_MODULE="ON"
                ONLY_BUILD_SHMEM_ALLOCATOR_MODULE="ON"
                ;;
            * )
                echo "Error: Invalid Value"
-               echo "Allowed value: deepep|kernels|deepep-adapter|deepep-kernels|memory-saver"
+               echo "Allowed value: deepep|kernels|deepep-adapter|deepep-kernels|memory-saver|shmem-allocator"
                exit 1
                ;;
        esac
@@ -61,6 +68,7 @@ while getopts ":a:hd" opt; do
        echo "    deepep-adapter     Only build deepep adapter layer and use old build of deepep kernels."
        echo "    deepep-kernels     Only build deepep kernels and use old build of deepep adapter layer."
        echo "    memory-saver       Only build torch_memory_saver (under contrib)."
        echo "    shmem-allocator    Only build torch-shmem-allocator (under contrib)."
        exit 1
        ;;
    \? )
@@ -113,6 +121,7 @@ function build_kernels()
{
    if [[ "$ONLY_BUILD_DEEPEP_KERNELs_MODULE" == "ON" ]]; then return 0; fi
    if [[ "$ONLY_BUILD_MEMORY_SAVER_MODULE" == "ON" ]]; then return 0; fi
    if [[ "$ONLY_BUILD_SHMEM_ALLOCATOR_MODULE" == "ON" ]]; then return 0; fi

    CMAKE_DIR=""
    BUILD_DIR="build"
@@ -172,6 +181,20 @@ function build_memory_saver()
    cd -
}

function build_shmem_allocator()
{
if [[ "$BUILD_SHMEM_ALLOCATOR_MODULE" != "ON" ]]; then return 0; fi
echo "[shmem_allocator] Building shmem_allocator via setup.py"
cd contrib/shmem_allocator/python || exit
rm -rf "$CURRENT_DIR"/contrib/shmem_allocator/python/build
rm -rf "$CURRENT_DIR"/contrib/shmem_allocator/python/dist
python3 setup.py clean --all
python3 setup.py bdist_wheel
mv -v "$CURRENT_DIR"/contrib/shmem_allocator/python/dist/shmem_allocator*.whl "${OUTPUT_DIR}/"
rm -rf "$CURRENT_DIR"/contrib/shmem_allocator/python/dist
cd -
}

function make_deepep_package()
{
    cd python/deep_ep || exit
@@ -209,6 +232,7 @@ function main()
        pip3 install wheel==0.45.1
    fi
    build_memory_saver
    build_shmem_allocator
    if [[ "$BUILD_DEEPEP_MODULE" == "ON" ]]; then
        make_deepep_package
    fi
82 changes: 82 additions & 0 deletions contrib/shmem_allocator/README.md
@@ -0,0 +1,82 @@
# Torch SHMEM Allocator

A PyTorch pluggable allocator built on the [Ascend shmem library](https://gitee.com/ascend/shmem).

## Build & Install shmem allocator

- Source CANN<br>
```bash
# Assuming CANN is installed in /usr/local/Ascend/ascend-toolkit
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```

- Install the shmem library<br>
Please refer to shmem's [README.md](https://gitee.com/ascend/shmem/blob/master/README.md#%E4%B8%89%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B)

- Source shmem's `set_env.sh` file<br>
```bash
# Assuming shmem is installed in /usr/local/Ascend/shmem
source /usr/local/Ascend/shmem/latest/set_env.sh
```

- Build the shmem allocator
```bash
# First, change to the sgl-kernel-npu project root
cd sgl-kernel-npu
# Then build the shmem allocator via build.sh
bash build.sh -a shmem-allocator
```

- Install the shmem allocator
```bash
pip install output/shmem_allocator-*.whl
```
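
A quick import of the two public entry points (the same ones used in the sglang examples below) is a minimal sanity check that the wheel installed correctly:
```python
# Sanity check: both public APIs should import without error.
from shmem_allocator import switch_to_shmem_allocator, init_shmem
```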

## Use the shmem allocator in sglang

The shmem allocator provides two Python APIs.<br>
- `switch_to_shmem_allocator()`: Switches PyTorch's allocator to the shmem allocator.

- `init_shmem(my_rank, n_ranks, local_mem_size, meta_size, ip_port)`: Initializes the underlying shmem library.<br>
  **Parameters:**
  * `my_rank`: rank of the current process.
  * `n_ranks`: global world size.
  * `local_mem_size`: size of the shmem pool to pre-allocate on each NPU.
  * `meta_size`: the portion of `local_mem_size` that is guaranteed to be allocated symmetrically across all NPUs.
  * `ip_port`: the `ip:port` endpoint used for inter-PE bootstrap and synchronization.
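
Taken together, a minimal call sequence looks like the sketch below. The values are illustrative only: the rank and world size come from hypothetical `RANK`/`WORLD_SIZE` environment variables, and the pool size and endpoint mirror the sglang example in the next section.
```python
import os

from shmem_allocator import init_shmem, switch_to_shmem_allocator

# Switch PyTorch to the shmem allocator first. The sglang example below
# does this as soon as the scheduler process starts, before any device
# memory has been allocated.
switch_to_shmem_allocator()

# ... select the NPU device here (see the model_runner.py diff below) ...

# Then initialize the underlying shmem library, once per process.
my_rank = int(os.environ.get("RANK", "0"))        # hypothetical env var
n_ranks = int(os.environ.get("WORLD_SIZE", "1"))  # hypothetical env var
init_shmem(
    my_rank,
    n_ranks,
    39 * (1024 ** 3),        # local_mem_size: 39 GiB pool per NPU
    0,                       # meta_size: no symmetric metadata region
    "tcp://127.0.0.1:3366",  # ip_port: bootstrap endpoint
)
```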


### Usage example
In sglang's `sglang/srt/managers/scheduler.py`, switch the scheduler process's NPU allocator to the shmem allocator as soon as the scheduler starts.
```diff
...
def run_scheduler_process(
server_args: ServerArgs,
port_args: PortArgs,
gpu_id: int,
tp_rank: int,
moe_ep_rank: int,
pp_rank: int,
dp_rank: Optional[int],
pipe_writer,
):
+ from shmem_allocator import switch_to_shmem_allocator
+ switch_to_shmem_allocator()
# Generate the logger prefix
prefix = ""
...
```

In sglang's `sglang/srt/model_executor/model_runner.py`, initialize shmem in the scheduler process immediately after the device is set via `set_device` (see the diff below). The example pre-allocates a 39 GiB pool per NPU (`39 * (1024 ** 3)` bytes) and passes `meta_size=0`, so no symmetric metadata region is reserved.
```diff
...
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")

try:
torch.get_device_module(self.device).set_device(self.gpu_id)
+ from shmem_allocator import init_shmem
+ init_shmem(self.tp_rank, self.tp_size, 39 * (1024 ** 3), 0, 'tcp://127.0.0.1:3366')
except Exception:
...
```
20 changes: 20 additions & 0 deletions contrib/shmem_allocator/csrc/NPUCommon.cpp
@@ -0,0 +1,20 @@
#include "NPUCommon.h"
#include <sstream>
#include <string>

bool OptionsManager::IsHcclZeroCopyEnable = false;
bool OptionsManager::CheckForceUncached = false;

std::string formatErrorCode(int32_t errorCode)
{
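    // Device index and rank lookups are stubbed out for now (see the
    // commented-out c10_npu calls below); only the raw error code is reported.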
    // if (c10_npu::option::OptionsManager::IsCompactErrorOutput()) {
    //     return "";
    // }
    std::ostringstream oss;
    // int deviceIndex = -1;
    // c10_npu::GetDevice(&deviceIndex);
    // auto rank_id = c10_npu::option::OptionsManager::GetRankId();
    oss << "\n[ERROR] CODE" << static_cast<int>(errorCode);

    return oss.str();
}
38 changes: 38 additions & 0 deletions contrib/shmem_allocator/csrc/NPUCommon.h
@@ -0,0 +1,38 @@
#pragma once

#include <mutex>
#include <atomic>
#include <cstdint>  // int32_t
#include <sstream>  // std::ostringstream, used by the check macros below
#include <string>   // std::string

struct OptionsManager {
    static bool IsHcclZeroCopyEnable;
    static bool CheckForceUncached;
};

std::string formatErrorCode(int32_t errorCode);

#define PTA_ERROR_MOCK(err_code) formatErrorCode((int32_t)err_code)
#define OPS_ERROR_MOCK(err_code) formatErrorCode((int32_t)err_code)

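// Mock check macros: on failure they log through ASCEND_LOGE / ASCEND_LOGW and
// continue instead of throwing. ACL_ERROR_NONE and the ASCEND_LOG* macros are
// expected to come from the ACL / logging headers of the including translation unit.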
#define NPU_CHECK_ERROR_MOCK(err_code, ...) \
    do { \
        int error_code = (err_code); \
        if ((error_code) != ACL_ERROR_NONE) { \
            std::ostringstream oss; \
            oss << " NPU function error: [ShmemAllocator currently does not support detailed error logs]" << std::endl; \
            std::string err_msg = oss.str(); \
            ASCEND_LOGE("%s", err_msg.c_str()); \
        } \
    } while (0)

#define NPU_CHECK_WARN_MOCK(err_code, ...) \
    do { \
        int error_code = (err_code); \
        if ((error_code) != ACL_ERROR_NONE) { \
            std::ostringstream oss; \
            oss << " NPU function warning: [ShmemAllocator currently does not support detailed warning logs]" << std::endl; \
            std::string err_msg = oss.str(); \
            ASCEND_LOGW("%s", err_msg.c_str()); \
        } \
    } while (0)

const int32_t ACL_SYNC_TIMEOUT = 3600 * 1000; // ms