
Commit e9b2ca4

yuxianq authored and sherry-1001 committed
[None][feat] Async pp send for PPCommTorch. (NVIDIA#9976)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent 9072eea commit e9b2ca4

File tree

cpp/tensorrt_llm/thop/ncclCommunicatorOp.cpp
tensorrt_llm/_torch/device_mesh.py
tensorrt_llm/_torch/distributed/communicator.py
tensorrt_llm/mapping.py

4 files changed: +30 -54 lines changed

cpp/tensorrt_llm/thop/ncclCommunicatorOp.cpp

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,7 @@ NcclCommunicatorOp::NcclCommunicatorOp(int64_t worldSize, int64_t rank)
 
 void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
 {
+    tensor.record_stream(at::cuda::getCurrentCUDAStream());
     auto ptr = static_cast<std::uint8_t*>(tensor.data_ptr());
     size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
     tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
@@ -41,6 +42,7 @@ void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
 
 void NcclCommunicatorOp::recv(th::Tensor& tensor, int64_t fromRank) const
 {
+    tensor.record_stream(at::cuda::getCurrentCUDAStream());
     auto ptr = static_cast<std::uint8_t*>(tensor.data_ptr());
     size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
     tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
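
Context (not part of this commit): record_stream is the standard PyTorch caching-allocator guard for a tensor consumed on a stream other than the one it was allocated on; without it, the allocator may recycle the tensor's memory before the side-stream send finishes. A minimal, hypothetical Python sketch of the same pattern the C++ change relies on (assumes a CUDA device; names are illustrative):

import torch

send_stream = torch.cuda.Stream()
ready = torch.cuda.Event()

x = torch.randn(1024, device="cuda")   # allocated on the current (default) stream
ready.record()                          # mark the point where x is fully produced

with torch.cuda.stream(send_stream):
    ready.wait()                        # send_stream waits until x is ready
    x.record_stream(send_stream)        # tell the allocator x is in use on send_stream
    y = x * 2                           # stand-in for an NCCL send issued on send_stream

del x  # safe: x's block is not reused until send_stream's pending work completes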

tensorrt_llm/_torch/device_mesh.py

Lines changed: 6 additions & 6 deletions
@@ -3,7 +3,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed import get_process_group_ranks
+from torch.distributed import ProcessGroup, get_process_group_ranks
 from torch.distributed.device_mesh import init_device_mesh
 
 from tensorrt_llm.logger import logger
@@ -48,27 +48,27 @@ class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck):
     # Access Torch ProcessGroup
     @property
     @require_device_mesh
-    def tp_group_pg(self):
+    def tp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('tp').get_group()
 
     @property
     @require_device_mesh
-    def pp_group_pg(self):
+    def pp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('pp').get_group()
 
     @property
     @require_device_mesh
-    def cp_group_pg(self):
+    def cp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('cp').get_group()
 
     @property
     @require_device_mesh
-    def moe_tp_group_pg(self):
+    def moe_tp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('moe_tp').get_group()
 
     @property
     @require_device_mesh
-    def moe_ep_group_pg(self):
+    def moe_ep_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('moe_ep').get_group()
 
     # Access rank
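
Background (not from this commit): these properties hand back the torch.distributed ProcessGroup behind each DeviceMesh dimension, which is what the new ProcessGroup annotations describe. A hedged sketch of the underlying PyTorch pattern (assumes a multi-rank torchrun launch on one node; the one-dimensional "pp" mesh is illustrative):

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, get_process_group_ranks
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical: launched via `torchrun --nproc-per-node=N ...`.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())  # single-node assumption

# One-dimensional mesh named "pp"; real meshes would also carry tp/cp/moe dims.
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("pp",))

pp_group: ProcessGroup = mesh.get_group("pp")  # what a *_group_pg property returns
print(get_process_group_ranks(pp_group))       # global ranks participating in this group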

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 16 additions & 43 deletions
@@ -16,7 +16,6 @@
 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
-from tensorrt_llm._torch.hostfunc import hostfunc
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -783,26 +782,16 @@ def pp_broadcast(self, obj, root=0):
         return ret[0]
 
 
-class PPCommBase:
+class PPCommNCCL:
 
     def __init__(self, global_mapping: Mapping):
         self.mapping = global_mapping
+        self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
+            self.mapping.world_size,
+            self.mapping.rank,
+        )
         self.tensor_ready_event = torch.cuda.Event()
         self.send_stream = torch.cuda.Stream()
-        self.tensor_cache = {}
-
-    def _cache_tensor(self, tensor: torch.Tensor):
-        cache_id = id(tensor)
-        self.tensor_cache[cache_id] = tensor
-
-    @hostfunc
-    def _release_tensor(self, tensor: torch.Tensor):
-        cache_id = id(tensor)
-        del self.tensor_cache[cache_id]
-
-    @abstractmethod
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        raise NotImplementedError("direct_send is not implemented")
 
     def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
         if dest is None:
@@ -811,63 +800,47 @@ def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
         # NCCL send kernel in send_stream cannot be captured,
         # so we send in the current stream instead in CUDA graph cases.
         if torch.cuda.is_current_stream_capturing():
-            self.direct_send(tensor, dest)
+            self.nccl_comm.send(tensor, dest)
             return
 
         self.tensor_ready_event.record()
         with torch.cuda.stream(self.send_stream):
             self.tensor_ready_event.wait()
-            # tensor may be released before NCCL send finished,
-            # so we cache it first and release it after send finished.
-            self._cache_tensor(tensor)
-            self.direct_send(tensor, dest)
-            self._release_tensor(tensor)
-
-
-class PPCommNCCL(PPCommBase):
-
-    def __init__(self, global_mapping: Mapping):
-        super().__init__(global_mapping)
-        self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
-            self.mapping.world_size,
-            self.mapping.rank,
-        )
-
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        self.nccl_comm.send(tensor, dest)
+            self.nccl_comm.send(tensor, dest)
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         if src is None:
             src = self.mapping.prev_pp_rank()
         self.nccl_comm.recv(tensor, src)
 
 
-class PPCommTorch(PPCommBase):
+class PPCommTorch:
 
     def __init__(self, global_mapping: Mapping):
-        super().__init__(global_mapping)
+        self.mapping = global_mapping
         self.pg = self.mapping.pp_group_pg
         self.pg_group = self.mapping.pp_group
 
     def _global_to_local_rank(self, global_rank: int):
        assert global_rank in self.pg_group
        return self.pg_group.index(global_rank)
 
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
-
-    # TODO: support async pp send for PPCommTorch
     def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
         if dest is None:
             dest = self.mapping.next_pp_rank()
 
-        self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
+        work = self.pg.send([tensor], self._global_to_local_rank(dest), tag=0)
+        # Send operation cannot be captured without blocking wait,
+        # so we block the current stream in CUDA graph cases.
+        if torch.cuda.is_current_stream_capturing():
+            work.block_current_stream()
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         if src is None:
             src = self.mapping.prev_pp_rank()
 
-        self.pg.recv([tensor], self._global_to_local_rank(src), tag=0).wait()
+        work = self.pg.recv([tensor], self._global_to_local_rank(src), tag=0)
+        work.block_current_stream()
 
 
 _pp_comm = None
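
To illustrate the pattern the new PPCommTorch.send/recv follow (this sketch is not TensorRT-LLM code): a point-to-point ProcessGroup op returns a Work handle immediately, and instead of a host-side wait() the CUDA stream is ordered behind the op only where required. It assumes a recent PyTorch whose Work exposes block_current_stream(), a 2+ rank NCCL job, and hypothetical helper names:

import torch
import torch.distributed as dist

def pp_send_async(pg: dist.ProcessGroup, tensor: torch.Tensor, peer: int) -> dist.Work:
    # Kick off the send and return right away; the caller keeps the Work handle.
    work = pg.send([tensor], peer, tag=0)
    if torch.cuda.is_current_stream_capturing():
        # A host-side work.wait() cannot be captured into a CUDA graph, so order
        # the capturing stream behind the send on-device instead.
        work.block_current_stream()
    return work

def pp_recv(pg: dist.ProcessGroup, tensor: torch.Tensor, peer: int) -> None:
    work = pg.recv([tensor], peer, tag=0)
    # The receiver must not touch `tensor` before the data lands, so always
    # block the current stream behind the recv.
    work.block_current_stream()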

tensorrt_llm/mapping.py

Lines changed: 6 additions & 5 deletions
@@ -16,6 +16,7 @@
 from typing import List
 
 import torch
+from torch.distributed import ProcessGroup
 
 from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
 from tensorrt_llm._utils import mpi_disabled
@@ -518,23 +519,23 @@ def repurpose_helix_cp_to_tp(self):
 
     # DeviceMesh specific methods
     @property
-    def tp_group_pg(self):
+    def tp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("tp_group_pg is not implemented.")
 
     @property
-    def pp_group_pg(self):
+    def pp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("pp_group_pg is not implemented.")
 
     @property
-    def cp_group_pg(self):
+    def cp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("cp_group_pg is not implemented.")
 
     @property
-    def moe_tp_group_pg(self):
+    def moe_tp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("moe_tp_group_pg is not implemented.")
 
     @property
-    def moe_ep_group_pg(self):
+    def moe_ep_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("moe_ep_group_pg is not implemented.")
 
     def build_mesh(self):
