1616except Exception :
1717 MPI = None # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
1818
19- from tensorrt_llm ._torch .hostfunc import hostfunc
2019from tensorrt_llm ._utils import (mpi_allgather , mpi_barrier , mpi_comm ,
2120 mpi_disabled , mpi_isend , mpi_isend_object ,
2221 mpi_recv , mpi_recv_object , mpi_send ,
@@ -783,26 +782,16 @@ def pp_broadcast(self, obj, root=0):
783782 return ret [0 ]
784783
785784
786- class PPCommBase :
785+ class PPCommNCCL :
787786
788787 def __init__ (self , global_mapping : Mapping ):
789788 self .mapping = global_mapping
789+ self .nccl_comm = torch .classes .trtllm .NcclCommunicatorOp (
790+ self .mapping .world_size ,
791+ self .mapping .rank ,
792+ )
790793 self .tensor_ready_event = torch .cuda .Event ()
791794 self .send_stream = torch .cuda .Stream ()
792- self .tensor_cache = {}
793-
794- def _cache_tensor (self , tensor : torch .Tensor ):
795- cache_id = id (tensor )
796- self .tensor_cache [cache_id ] = tensor
797-
798- @hostfunc
799- def _release_tensor (self , tensor : torch .Tensor ):
800- cache_id = id (tensor )
801- del self .tensor_cache [cache_id ]
802-
803- @abstractmethod
804- def direct_send (self , tensor : torch .Tensor , dest : int ):
805- raise NotImplementedError ("direct_send is not implemented" )
806795
807796 def send (self , tensor : torch .Tensor , dest : Optional [int ] = None ):
808797 if dest is None :
@@ -811,63 +800,47 @@ def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
811800 # NCCL send kernel in send_stream cannot be captured,
812801 # so we send in the current stream instead in CUDA graph cases.
813802 if torch .cuda .is_current_stream_capturing ():
814- self .direct_send (tensor , dest )
803+ self .nccl_comm . send (tensor , dest )
815804 return
816805
817806 self .tensor_ready_event .record ()
818807 with torch .cuda .stream (self .send_stream ):
819808 self .tensor_ready_event .wait ()
820- # tensor may be released before NCCL send finished,
821- # so we cache it first and release it after send finished.
822- self ._cache_tensor (tensor )
823- self .direct_send (tensor , dest )
824- self ._release_tensor (tensor )
825-
826-
827- class PPCommNCCL (PPCommBase ):
828-
829- def __init__ (self , global_mapping : Mapping ):
830- super ().__init__ (global_mapping )
831- self .nccl_comm = torch .classes .trtllm .NcclCommunicatorOp (
832- self .mapping .world_size ,
833- self .mapping .rank ,
834- )
835-
836- def direct_send (self , tensor : torch .Tensor , dest : int ):
837- self .nccl_comm .send (tensor , dest )
809+ self .nccl_comm .send (tensor , dest )
838810
839811 def recv (self , tensor : torch .Tensor , src : Optional [int ] = None ):
840812 if src is None :
841813 src = self .mapping .prev_pp_rank ()
842814 self .nccl_comm .recv (tensor , src )
843815
844816
845- class PPCommTorch ( PPCommBase ) :
817+ class PPCommTorch :
846818
847819 def __init__ (self , global_mapping : Mapping ):
848- super (). __init__ ( global_mapping )
820+ self . mapping = global_mapping
849821 self .pg = self .mapping .pp_group_pg
850822 self .pg_group = self .mapping .pp_group
851823
852824 def _global_to_local_rank (self , global_rank : int ):
853825 assert global_rank in self .pg_group
854826 return self .pg_group .index (global_rank )
855827
856- def direct_send (self , tensor : torch .Tensor , dest : int ):
857- self .pg .send ([tensor ], self ._global_to_local_rank (dest ), tag = 0 ).wait ()
858-
859- # TODO: support async pp send for PPCommTorch
860828 def send (self , tensor : torch .Tensor , dest : Optional [int ] = None ):
861829 if dest is None :
862830 dest = self .mapping .next_pp_rank ()
863831
864- self .pg .send ([tensor ], self ._global_to_local_rank (dest ), tag = 0 ).wait ()
832+ work = self .pg .send ([tensor ], self ._global_to_local_rank (dest ), tag = 0 )
833+ # Send operation cannot be captured without blocking wait,
834+ # so we block the current stream in CUDA graph cases.
835+ if torch .cuda .is_current_stream_capturing ():
836+ work .block_current_stream ()
865837
866838 def recv (self , tensor : torch .Tensor , src : Optional [int ] = None ):
867839 if src is None :
868840 src = self .mapping .prev_pp_rank ()
869841
870- self .pg .recv ([tensor ], self ._global_to_local_rank (src ), tag = 0 ).wait ()
842+ work = self .pg .recv ([tensor ], self ._global_to_local_rank (src ), tag = 0 )
843+ work .block_current_stream ()
871844
872845
873846_pp_comm = None