
Commit b1e2a27

yingxudeng authored and liutongxuan committed
feat: add NPU process group initialization and management.
1 parent e045db1 commit b1e2a27

File tree

10 files changed: +113 −119 lines changed


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -344,6 +344,7 @@ if(USE_NPU)
     $ENV{PYTORCH_INSTALL_PATH}/include
     $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
     $ENV{PYTORCH_NPU_INSTALL_PATH}/include
+    $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed
     $ENV{NPU_HOME_PATH}/include
     $ENV{ATB_HOME_PATH}/include
     $ENV{NPU_HOME_PATH}/opp/vendors/xllm/op_api/include/

xllm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb ZLIB::ZLIB p
 add_dependencies(xllm brpc-static)

 if(USE_NPU)
-  set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext)
+  set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext torch_npu torch_python)
 elseif(USE_MLU)
   set(COMMON_LIBS Python::Python)
 endif()

xllm/core/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ cc_library(
     absl::random_random
     absl::strings
     torch
+    $<$<BOOL:${USE_NPU}>:torch_python>
     $<$<BOOL:${USE_NPU}>:torch_npu>
     $<$<BOOL:${USE_MSPTI}>:mspti>
     $<$<BOOL:${USE_NPU}>:ms_tools_ext>

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 0 deletions
@@ -464,3 +464,9 @@ DEFINE_bool(enable_constrained_decoding,
             "Whether to enable constrained decoding, which is used to ensure "
             "that the output meets specific format or structural requirements "
             "through pre-defined rules.");
+
+
+DEFINE_string(
+    npu_kernel_backend,
+    "ATB",
+    "NPU kernel backend. Supported options: ATB, TORCH. Default is ATB.");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
@@ -226,3 +226,5 @@ DECLARE_int64(dit_cache_skip_interval_steps);
 DECLARE_double(dit_cache_residual_diff_threshold);

 DECLARE_bool(enable_constrained_decoding);
+
+DECLARE_string(npu_kernel_backend);
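
For orientation: a minimal sketch of how a caller might branch on the new flag. Only FLAGS_npu_kernel_backend comes from this commit; the helper below is hypothetical.

    #include <gflags/gflags.h>
    #include <glog/logging.h>

    DECLARE_string(npu_kernel_backend);

    // Hypothetical helper: true selects the torch-managed HCCL path,
    // false keeps the default ATB path; anything else aborts.
    inline bool use_torch_hccl_backend() {
      if (FLAGS_npu_kernel_backend == "TORCH") return true;
      if (FLAGS_npu_kernel_backend == "ATB") return false;
      LOG(FATAL) << "Unsupported npu_kernel_backend: "
                 << FLAGS_npu_kernel_backend;
      return false;
    }

With standard gflags parsing, the backend is switched at launch time via --npu_kernel_backend=TORCH; the default stays ATB.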

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 0 additions & 2 deletions
@@ -104,9 +104,7 @@ void WorkerServer::create_server(

   CollectiveCommunicator comm(worker_global_rank, world_size, dp_size, ep_size);
   const ParallelArgs* parallel_args = comm.parallel_args();
-#if defined(USE_MLU) || defined(USE_CUDA)
   comm.create_process_groups(master_node_addr, device);
-#endif

   std::unique_ptr<Worker> worker =
       std::make_unique<Worker>(*parallel_args, device, options, worker_type);

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 13 additions & 17 deletions
@@ -18,6 +18,7 @@ limitations under the License.
 #include "mapping_npu.h"

 #if defined(USE_NPU)
+#include "npu_process_group.h"
 #include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h"
 #include "xllm_kernels/core/include/atb_speed/utils/singleton.h"
 #include "xllm_kernels/models/base/param/mapping.h"
@@ -30,23 +31,6 @@ limitations under the License.
 #include "parallel_args.h"
 #include "util/net.h"

-namespace {
-#if defined(USE_NPU)
-std::unique_ptr<xllm::ProcessGroup> create_process_group(
-    int rank,
-    int world_size,
-    int rank_size,
-    int port,
-    bool trans,
-    const std::string& host,
-    const std::string& group_name,
-    const torch::Device& device) {
-  LOG(FATAL) << "Unsupported device type";
-  return nullptr;
-}
-#endif
-}  // namespace
-
 namespace xllm {

 CollectiveCommunicator::CollectiveCommunicator(int global_rank,
@@ -72,6 +56,13 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
   // std::make_unique<ProcessGroupHCCL>(
   //     global_rank, world_size, device, comm);

+  // communicator will be initialized in torch.
+  if (FLAGS_npu_kernel_backend == "TORCH") {
+    parallel_args_ = std::make_unique<ParallelArgs>(
+        global_rank, world_size, dp_size, nullptr, ep_size);
+    return;
+  }
+
   // communicator will be initialized in atb.
   MappingNPU::Options mapping_options;
   mapping_options.dp_size(dp_size)
@@ -116,6 +107,11 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
 void CollectiveCommunicator::create_process_groups(
     const std::string& master_addr,
     const torch::Device& device) {
+#if defined(USE_NPU)
+  if (FLAGS_npu_kernel_backend == "ATB") {
+    return;
+  }
+#endif
   std::string host;
   int port;
   net::parse_host_port_from_addr(master_addr, host, port);
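
Taken together with the worker_server.cpp change above, the two guards split responsibility by backend. A condensed sketch of the resulting NPU flow, paraphrased from the diff rather than quoted from it:

    // NPU worker startup after this commit (simplified sketch).
    CollectiveCommunicator comm(worker_global_rank, world_size, dp_size, ep_size);
    //   TORCH backend: the constructor only records ParallelArgs and returns early.
    //   ATB backend:   the constructor builds the ATB mapping/comm manager as before.

    comm.create_process_groups(master_node_addr, device);
    //   ATB backend:   returns immediately; the groups are owned by ATB.
    //   TORCH backend: parses master_addr and creates the HCCL process groups.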

xllm/core/framework/parallel_state/npu_process_group.cpp

Lines changed: 56 additions & 94 deletions
@@ -14,6 +14,16 @@ limitations under the License.
 ==============================================================================*/

 #include "npu_process_group.h"
+#ifdef TORCH_HIGHER_THAN_PTA6
+#include <torch_npu/csrc/framework/OpCommand.h>
+#else
+#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
+#include <torch_npu/csrc/framework/utils/OpPreparation.h>
+#endif
+
+#include <c10d/ProcessGroup.hpp>
+#include <c10d/TCPStore.hpp>
+#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>

 namespace {

@@ -24,113 +34,65 @@ namespace {
     LOG(FATAL) << "Failed, HCCL error :" << HcclGetErrorString(r); \
   }                                                                \
 } while (0)
+}  // namespace

-inline bool is_npu(const at::Tensor& tensor) {
-  if (!tensor.defined()) {
-    return false;
-  }
-  return tensor.device().is_privateuseone();
-}
-
-inline bool is_npu(const at::TensorOptions& options) {
-  return options.device().is_privateuseone();
-}
+namespace xllm {

-inline bool is_npu(const at::Device& device) {
-  return device.is_privateuseone();
-}
+ProcessGroupHCCL::ProcessGroupHCCL(int global_rank,
+                                   int world_size,
+                                   int rank_size,
+                                   int port,
+                                   bool trans,
+                                   const std::string& host,
+                                   const std::string& group_name,
+                                   const torch::Device& device)
+    : ProcessGroup(device) {
+  c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL::Options> hccl_pg_options =
+      c10d_npu::ProcessGroupHCCL::Options::create();
+  // hccl_pg_options->group_name = group_name;
+  int rank = global_rank;
+  if (world_size != rank_size) {
+    auto [local_rank, group_ranks] =
+        get_group_rank(world_size, global_rank, rank_size, trans);
+    std::vector<uint32_t> uint32_ranks;
+    for (auto rank : group_ranks) {
+      uint32_ranks.push_back(static_cast<uint32_t>(rank));
+    }
+    hccl_pg_options->global_ranks_in_group = uint32_ranks;
+    rank = local_rank;
+  }

-at::Tensor flatten_for_scatter_gather(std::vector<at::Tensor>& tensors) {
-  auto& t = tensors[0];
-  std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
-  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
-  return at::empty(sizes, t.options());
+  auto store = create_tcp_store(host, port, rank);
+  pg_ = std::make_unique<c10d_npu::ProcessGroupHCCL>(
+      store, rank, rank_size, hccl_pg_options);
 }

-HcclDataType to_hccl_data_type(const torch::Tensor& input) {
-  const auto type = input.scalar_type();
-  switch (type) {
-    case at::kFloat:
-      return HCCL_DATA_TYPE_FP32;
-    case at::kHalf:
-      return HCCL_DATA_TYPE_FP16;
-    case at::kDouble:
-      return HCCL_DATA_TYPE_FP64;
-    case at::kLong:
-      return HCCL_DATA_TYPE_INT64;
-    case at::kInt:
-      return HCCL_DATA_TYPE_INT32;
-    case at::kChar:
-      return HCCL_DATA_TYPE_INT8;
-    case at::kByte:
-      return HCCL_DATA_TYPE_UINT8;
-    case at::kBool:
-      return HCCL_DATA_TYPE_UINT8;
-    case at::kBFloat16:
-      return HCCL_DATA_TYPE_BFP16;
-    default:
-      LOG(FATAL) << "Unconvertible HCCL type: " << type;
+// Destructor.
+ProcessGroupHCCL::~ProcessGroupHCCL() {
+  if (pg_) {
+    pg_->shutdown();
+  } else {
+    HCCLCHECK(HcclCommDestroy(comm_));
   }
 }

-void check_input(torch::Tensor input) {
-  CHECK(is_npu(input)) << "input should be npu tensor";
-  CHECK(input.is_contiguous()) << "input should be contiguous";
-  CHECK(!input.is_sparse()) << "input have to be npu dense tensor";
-}
-
-}  // namespace
-
-namespace xllm {
-
 ProcessGroupHCCL::ProcessGroupHCCL(int rank,
                                    int world_size,
                                    const torch::Device& device,
                                    HcclComm comm)
     : ProcessGroup(device), comm_(comm) {}
-// Destructor.
-ProcessGroupHCCL::~ProcessGroupHCCL() { HCCLCHECK(HcclCommDestroy(comm_)); }

-void ProcessGroupHCCL::allreduce(torch::Tensor& input) {
-  DCHECK(input.device() == device())
-      << "input should be on the same device as the process group";
-  check_input(input);
-  // inplace all reduce
-  // const auto count = input.numel();
-  // const auto data_type = to_hccl_data_type(input);
-  // auto stream = c10_npu::getCurrentNPUStream();
-  // torch::DeviceGuard device_guard(device());
-  // HCCLCHECK(HcclAllReduce(
-  //     /*sendbuff=*/input.data_ptr(),
-  //     /*recvbuff=*/input.data_ptr(),
-  //     /*count=*/count,
-  //     /*datatype=*/data_type,
-  //     /*op=*/HCCL_REDUCE_SUM,
-  //     /*comm=*/comm_,
-  //     /*stream=*/stream));
-}
-void ProcessGroupHCCL::allgather(const torch::Tensor& input,
-                                 std::vector<torch::Tensor>& outputs) {
-  check_input(input);
-  // CHECK(outputs.size() == world_size())
-  //     << "outputs should have the same size as world_size";
-  // DCHECK(input.device() == device())
-  //     << "input should be on the same device as the process group";
-  // torch::DeviceGuard device_guard(device());
-  // torch::Tensor flattened_output = flatten_for_scatter_gather(outputs);
-  // const auto count = input.numel();
-  // const auto data_type = to_hccl_data_type(input);
-  // auto stream = c10_npu::getCurrentNPUStream();
-  // HCCLCHECK(HcclAllGather(
-  //     /*sendbuff=*/input.data_ptr(),
-  //     /*recvbuff=*/flattened_output.data_ptr(),
-  //     /*sendcount=*/count,
-  //     /*datatype=*/data_type,
-  //     /*comm=*/comm_,
-  //     /*stream=*/stream));
-  // // copy the flattened output tensors to the outputs.
-  // for (int i = 0; i < outputs.size(); ++i) {
-  //   outputs[i].copy_(flattened_output[i], /*non_blocking=*/true);
-  // }
+std::unique_ptr<xllm::ProcessGroup> create_process_group(
+    int rank,
+    int world_size,
+    int rank_size,
+    int port,
+    bool trans,
+    const std::string& host,
+    const std::string& group_name,
+    const torch::Device& device) {
+  return std::make_unique<ProcessGroupHCCL>(
+      rank, world_size, rank_size, port, trans, host, group_name, device);
 }
+
 }  // namespace xllm
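
As a usage sketch, the new factory simply forwards to the store-based constructor; when world_size differs from rank_size, that constructor derives the sub-group ranks via get_group_rank before creating the HCCL backend. The values below are hypothetical (a two-rank group on the local host); only the signature comes from the commit.

    #include "npu_process_group.h"

    // Hypothetical call site: build an HCCL process group for rank 0 of 2.
    torch::Device npu_device(c10::DeviceType::PrivateUse1, /*index=*/0);
    auto pg = xllm::create_process_group(/*rank=*/0,
                                         /*world_size=*/2,
                                         /*rank_size=*/2,
                                         /*port=*/29500,
                                         /*trans=*/false,
                                         /*host=*/"127.0.0.1",
                                         /*group_name=*/"tp_group",
                                         npu_device);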

xllm/core/framework/parallel_state/npu_process_group.h

Lines changed: 19 additions & 5 deletions
@@ -28,16 +28,30 @@ class ProcessGroupHCCL : public ProcessGroup {
                    const torch::Device& device,
                    HcclComm comm);

+  ProcessGroupHCCL(int rank,
+                   int world_size,
+                   int rank_size,
+                   int port,
+                   bool trans,
+                   const std::string& host,
+                   const std::string& group_name,
+                   const torch::Device& device);
+
   // Destructor.
   ~ProcessGroupHCCL() override;

-  void allreduce(torch::Tensor& input) override;
-
-  void allgather(const torch::Tensor& input,
-                 std::vector<torch::Tensor>& outputs) override;
-
  private:
   HcclComm comm_ = nullptr;
 };

+std::unique_ptr<xllm::ProcessGroup> create_process_group(
+    int rank,
+    int world_size,
+    int rank_size,
+    int port,
+    bool trans,
+    const std::string& host,
+    const std::string& group_name,
+    const torch::Device& device);
+
 }  // namespace xllm

xllm/core/framework/parallel_state/process_group.h

Lines changed: 14 additions & 0 deletions
@@ -19,6 +19,11 @@ limitations under the License.

 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
+
+#if defined(USE_NPU)
+#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>
+#endif
+
 namespace xllm {
 std::pair<int, std::vector<uint64_t>> get_group_rank(int world_size,
                                                      int global_rank,
@@ -60,7 +65,16 @@
   torch::Device device_;

  protected:
+#if defined(USE_NPU) &&                                    \
+    (TORCH_VERSION_MAJOR < 2 ||                            \
+     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 7))
+  // Using ProcessGroupHCCL for NPU devices
+  // Note: torch_npu uses an older torch version where c10d::Backend lacks
+  // shutdown() method
+  std::unique_ptr<c10d_npu::ProcessGroupHCCL> pg_{nullptr};
+#else
   std::unique_ptr<c10d::Backend> pg_{nullptr};
+#endif
 };

 }  // namespace xllm
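
Because c10d_npu::ProcessGroupHCCL exposes the same c10d collective interface as c10d::Backend, call sites that go through pg_ can stay the same under either branch. A minimal sketch, assuming it runs inside a ProcessGroup method with an NPU tensor input in scope:

    // Issue an in-place all-reduce through whichever backend pg_ holds.
    std::vector<at::Tensor> tensors = {input};
    auto work = pg_->allreduce(tensors);  // default c10d::AllreduceOptions
    work->wait();                         // block until the collective completes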
