@@ -14,6 +14,16 @@ limitations under the License.
 ==============================================================================*/
 
 #include "npu_process_group.h"
+#ifdef TORCH_HIGHER_THAN_PTA6
+#include <torch_npu/csrc/framework/OpCommand.h>
+#else
+#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
+#include <torch_npu/csrc/framework/utils/OpPreparation.h>
+#endif
+
+#include <c10d/ProcessGroup.hpp>
+#include <c10d/TCPStore.hpp>
+#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>
 
 namespace {
 
@@ -24,113 +34,65 @@ namespace {
     LOG(FATAL) << "Failed, HCCL error :" << HcclGetErrorString(r); \
   }                                                                \
 } while (0)
+}  // namespace
 
-inline bool is_npu(const at::Tensor& tensor) {
-  if (!tensor.defined()) {
-    return false;
-  }
-  return tensor.device().is_privateuseone();
-}
-
-inline bool is_npu(const at::TensorOptions& options) {
-  return options.device().is_privateuseone();
-}
+namespace xllm {
 
-inline bool is_npu(const at::Device& device) {
-  return device.is_privateuseone();
-}
+ProcessGroupHCCL::ProcessGroupHCCL(int global_rank,
+                                   int world_size,
+                                   int rank_size,
+                                   int port,
+                                   bool trans,
+                                   const std::string& host,
+                                   const std::string& group_name,
+                                   const torch::Device& device)
+    : ProcessGroup(device) {
+  c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL::Options> hccl_pg_options =
+      c10d_npu::ProcessGroupHCCL::Options::create();
+  // hccl_pg_options->group_name = group_name;
+  int rank = global_rank;
+  if (world_size != rank_size) {
+    auto [local_rank, group_ranks] =
+        get_group_rank(world_size, global_rank, rank_size, trans);
+    std::vector<uint32_t> uint32_ranks;
+    for (auto rank : group_ranks) {
+      uint32_ranks.push_back(static_cast<uint32_t>(rank));
+    }
+    hccl_pg_options->global_ranks_in_group = uint32_ranks;
+    rank = local_rank;
+  }
 
-at::Tensor flatten_for_scatter_gather(std::vector<at::Tensor>& tensors) {
-  auto& t = tensors[0];
-  std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
-  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
-  return at::empty(sizes, t.options());
+  auto store = create_tcp_store(host, port, rank);
+  pg_ = std::make_unique<c10d_npu::ProcessGroupHCCL>(
+      store, rank, rank_size, hccl_pg_options);
 }
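Note: the constructor above relies on two project-local helpers that are not part of this diff: `get_group_rank()` (which splits the `world_size` ranks into sub-groups of `rank_size`, with `trans` presumably selecting strided rather than contiguous grouping) and `create_tcp_store()` (the rendezvous store handed to `c10d_npu::ProcessGroupHCCL`). A minimal, hypothetical sketch of what they might look like is below, assuming contiguous-vs-strided grouping semantics and a PyTorch build that provides `c10d::TCPStoreOptions`; the real helpers in the repository may differ.

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <c10d/TCPStore.hpp>

// Hypothetical sketch of get_group_rank(): partition world_size ranks into
// groups of rank_size. trans == false -> contiguous blocks {0..rank_size-1},
// {rank_size..2*rank_size-1}, ...; trans == true -> strided groups
// {i, i+stride, i+2*stride, ...}. Returns {rank within the group, global
// ranks of the group}, matching how the constructor consumes the result.
inline std::pair<int, std::vector<int>> get_group_rank(int world_size,
                                                       int global_rank,
                                                       int rank_size,
                                                       bool trans) {
  std::vector<int> group_ranks;
  int local_rank = 0;
  if (!trans) {
    const int group_start = (global_rank / rank_size) * rank_size;
    for (int r = group_start; r < group_start + rank_size; ++r) {
      group_ranks.push_back(r);
    }
    local_rank = global_rank - group_start;
  } else {
    const int stride = world_size / rank_size;
    const int offset = global_rank % stride;
    for (int i = 0; i < rank_size; ++i) {
      group_ranks.push_back(offset + i * stride);
    }
    local_rank = global_rank / stride;
  }
  return {local_rank, group_ranks};
}

// Hypothetical sketch of create_tcp_store(): rank 0 hosts the TCP store used
// for HCCL rendezvous; every other rank connects to it as a client.
inline c10::intrusive_ptr<c10d::TCPStore> create_tcp_store(
    const std::string& host, int port, int rank) {
  c10d::TCPStoreOptions opts;
  opts.port = static_cast<std::uint16_t>(port);
  opts.isServer = (rank == 0);
  return c10::make_intrusive<c10d::TCPStore>(host, opts);
}
```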
 
-HcclDataType to_hccl_data_type(const torch::Tensor& input) {
-  const auto type = input.scalar_type();
-  switch (type) {
-    case at::kFloat:
-      return HCCL_DATA_TYPE_FP32;
-    case at::kHalf:
-      return HCCL_DATA_TYPE_FP16;
-    case at::kDouble:
-      return HCCL_DATA_TYPE_FP64;
-    case at::kLong:
-      return HCCL_DATA_TYPE_INT64;
-    case at::kInt:
-      return HCCL_DATA_TYPE_INT32;
-    case at::kChar:
-      return HCCL_DATA_TYPE_INT8;
-    case at::kByte:
-      return HCCL_DATA_TYPE_UINT8;
-    case at::kBool:
-      return HCCL_DATA_TYPE_UINT8;
-    case at::kBFloat16:
-      return HCCL_DATA_TYPE_BFP16;
-    default:
-      LOG(FATAL) << "Unconvertible HCCL type: " << type;
+// Destructor.
+ProcessGroupHCCL::~ProcessGroupHCCL() {
+  if (pg_) {
+    pg_->shutdown();
+  } else {
+    HCCLCHECK(HcclCommDestroy(comm_));
   }
 }
 
-void check_input(torch::Tensor input) {
-  CHECK(is_npu(input)) << "input should be npu tensor";
-  CHECK(input.is_contiguous()) << "input should be contiguous";
-  CHECK(!input.is_sparse()) << "input have to be npu dense tensor";
-}
-
-}  // namespace
-
-namespace xllm {
-
 ProcessGroupHCCL::ProcessGroupHCCL(int rank,
                                    int world_size,
                                    const torch::Device& device,
                                    HcclComm comm)
     : ProcessGroup(device), comm_(comm) {}
-// Destructor.
-ProcessGroupHCCL::~ProcessGroupHCCL() { HCCLCHECK(HcclCommDestroy(comm_)); }
 
-void ProcessGroupHCCL::allreduce(torch::Tensor& input) {
-  DCHECK(input.device() == device())
-      << "input should be on the same device as the process group";
-  check_input(input);
-  // inplace all reduce
-  // const auto count = input.numel();
-  // const auto data_type = to_hccl_data_type(input);
-  // auto stream = c10_npu::getCurrentNPUStream();
-  // torch::DeviceGuard device_guard(device());
-  // HCCLCHECK(HcclAllReduce(
-  //     /*sendbuff=*/input.data_ptr(),
-  //     /*recvbuff=*/input.data_ptr(),
-  //     /*count=*/count,
-  //     /*datatype=*/data_type,
-  //     /*op=*/HCCL_REDUCE_SUM,
-  //     /*comm=*/comm_,
-  //     /*stream=*/stream));
-}
-void ProcessGroupHCCL::allgather(const torch::Tensor& input,
-                                 std::vector<torch::Tensor>& outputs) {
-  check_input(input);
-  // CHECK(outputs.size() == world_size())
-  //     << "outputs should have the same size as world_size";
-  // DCHECK(input.device() == device())
-  //     << "input should be on the same device as the process group";
-  // torch::DeviceGuard device_guard(device());
-  // torch::Tensor flattened_output = flatten_for_scatter_gather(outputs);
-  // const auto count = input.numel();
-  // const auto data_type = to_hccl_data_type(input);
-  // auto stream = c10_npu::getCurrentNPUStream();
-  // HCCLCHECK(HcclAllGather(
-  //     /*sendbuff=*/input.data_ptr(),
-  //     /*recvbuff=*/flattened_output.data_ptr(),
-  //     /*sendcount=*/count,
-  //     /*datatype=*/data_type,
-  //     /*comm=*/comm_,
-  //     /*stream=*/stream));
-  // // copy the flattened output tensors to the outputs.
-  // for (int i = 0; i < outputs.size(); ++i) {
-  //   outputs[i].copy_(flattened_output[i], /*non_blocking=*/true);
-  // }
+std::unique_ptr<xllm::ProcessGroup> create_process_group(
+    int rank,
+    int world_size,
+    int rank_size,
+    int port,
+    bool trans,
+    const std::string& host,
+    const std::string& group_name,
+    const torch::Device& device) {
+  return std::make_unique<ProcessGroupHCCL>(
+      rank, world_size, rank_size, port, trans, host, group_name, device);
 }
+
 }  // namespace xllm
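For context, a minimal, hypothetical caller of the `create_process_group()` factory added in this commit might look like the sketch below. The host, port, device index, and the 8-rank/4-per-group split are illustrative only, and it assumes the factory is declared in `npu_process_group.h`.

```cpp
#include <memory>
#include <string>

#include <torch/torch.h>

#include "npu_process_group.h"

int main() {
  // Illustrative values: an 8-rank world split into HCCL sub-groups of 4.
  const int global_rank = 0;
  const int world_size = 8;
  const int rank_size = 4;    // ranks per sub-group
  const int port = 29500;     // TCP store port (example)
  const bool trans = false;   // contiguous grouping
  const std::string host = "127.0.0.1";

  // NPUs are exposed to PyTorch as the PrivateUse1 device type.
  torch::Device device(c10::DeviceType::PrivateUse1, /*index=*/0);

  std::unique_ptr<xllm::ProcessGroup> pg =
      xllm::create_process_group(global_rank,
                                 world_size,
                                 rank_size,
                                 port,
                                 trans,
                                 host,
                                 /*group_name=*/"tp_group",
                                 device);

  // ... run collectives through pg; destroying it shuts the HCCL group down.
  return 0;
}
```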