     }                                                                    \
   }
 
+#define DISPATCH_WORLD_SIZES_NO_DEFAULT(world_size, ...)                 \
+  switch (world_size) {                                                  \
+    INT_SWITCH_CASE(k_world_size, 8, __VA_ARGS__);                       \
+    INT_SWITCH_CASE(k_world_size, 4, __VA_ARGS__);                       \
+    INT_SWITCH_CASE(k_world_size, 2, __VA_ARGS__);                       \
+    default: {                                                           \
+      TORCH_CHECK(false, "Not implemented for world_size=", world_size); \
+    }                                                                    \
+  }
+
 #define DISPATCH_ALIGNMENTS_16_8_4(alignment, ...)                       \
   switch (alignment) {                                                   \
     INT_SWITCH_CASE(k_alignment, 16, __VA_ARGS__);                       \
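The new DISPATCH_WORLD_SIZES_NO_DEFAULT macro differs from DISPATCH_WORLD_SIZES only in its default case: instead of presumably falling back to a generic, non-specialized path, it rejects unsupported world sizes with a TORCH_CHECK. For context, here is a minimal sketch of how INT_SWITCH_CASE is presumably defined (it is not shown in this diff; this shape is inferred from its usage above):

// Hypothetical reconstruction (assumption, not part of this patch): each case
// binds the matched value to a constexpr name, then invokes the lambda passed
// through __VA_ARGS__ so the lambda body can use the name as a template argument.
#define INT_SWITCH_CASE(name, val, ...) \
  case val: {                           \
    constexpr int name = val;           \
    __VA_ARGS__();                      \
    break;                              \
  }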
@@ -493,6 +503,70 @@ constexpr size_t two_shot_all_reduce_max_num_threads = 512;
 template <typename T, int alignment, int k_world_size>
 static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
     void two_shot_all_reduce_kernel(
+        T** input_ptrs,
+        T* output_ptr,
+        size_t input_offset,
+        size_t numel,
+        uint32_t** signal_pads,
+        size_t rank,
+        size_t world_size) {
+  static_assert(alignment % sizeof(T) == 0);
+  constexpr size_t numel_per_thread = alignment / sizeof(T);
+
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
+
+  const size_t numel_per_rank =
+      at::round_up(numel, alignment * world_size) / world_size;
+  const size_t start = numel_per_rank * rank;
+
+  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
+  auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  for (size_t i = offset; i < numel_per_rank; i += stride) {
+    if (start + i >= numel) {
+      continue;
+    }
+    auto vec = load_and_reduce<T, alignment, k_world_size>(
+        input_ptrs, rank, world_size, input_offset + start + i);
+    // store to local buffer
+    st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
+  }
+
+  __syncthreads();
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
+  for (size_t i = offset; i < numel_per_rank; i += stride) {
+    Vec<alignment> tmp[k_world_size];
+#pragma unroll k_world_size
+    for (size_t step = 0; step < k_world_size; ++step) {
+      size_t remote_rank = (rank + step) % k_world_size;
+      size_t remote_start = numel_per_rank * remote_rank;
+      if (remote_start + i >= numel) {
+        continue;
+      }
+      tmp[step] = ld_vec<alignment>(
+          input_ptrs[remote_rank] + input_offset + remote_start + i);
+    }
+#pragma unroll k_world_size
+    for (size_t step = 0; step < k_world_size; ++step) {
+      size_t remote_rank = (rank + step) % k_world_size;
+      size_t remote_start = numel_per_rank * remote_rank;
+      if (remote_start + i >= numel) {
+        continue;
+      }
+      st_vec<alignment>(
+          output_ptr + remote_start + i, tmp[step]);
+    }
+  }
+  // need to make sure all blocks exit simultaneously so that the data
+  // is not corrupted by the subsequent kernels
+  __syncthreads();
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+}
+
+template <typename T, int alignment, int k_world_size>
+static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
+    void two_shot_all_reduce_kernel_inplace(
         T** input_ptrs,
         size_t input_offset,
         size_t numel,
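The out-of-place kernel above partitions the buffer exactly like the in-place variant: each rank owns a contiguous slice of numel_per_rank elements, rounded up so every slice is a whole number of aligned vectors, which is why both loops guard against indices past numel on the tail rank. A small host-side sketch of that partitioning math (not part of the patch; round_up is a stand-in for at::round_up, and alignment is treated as an element count here for simplicity):

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for at::round_up: round n up to a multiple of m.
static size_t round_up(size_t n, size_t m) {
  return (n + m - 1) / m * m;
}

int main() {
  const size_t numel = 1000, world_size = 4, alignment_elems = 8;
  // Every rank gets the same padded slice size, a multiple of the alignment.
  const size_t numel_per_rank =
      round_up(numel, alignment_elems * world_size) / world_size;
  for (size_t rank = 0; rank < world_size; ++rank) {
    const size_t start = numel_per_rank * rank;
    const size_t end = start + numel_per_rank;
    // The last rank's slice may extend past numel; only count valid elements.
    const size_t valid_end = end < numel ? end : numel;
    const size_t valid = start < numel ? valid_end - start : 0;
    std::printf("rank %zu covers [%zu, %zu), %zu valid elements\n",
                rank, start, end, valid);
  }
  return 0;
}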
@@ -528,8 +602,9 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
   sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
 }
 
-at::Tensor two_shot_all_reduce_(
+at::Tensor two_shot_all_reduce_impl(
     at::Tensor input,
+    std::optional<at::Tensor> output,
     std::string reduce_op,
     std::string group_name) {
   TORCH_CHECK(
@@ -546,6 +621,14 @@ at::Tensor two_shot_all_reduce_(
   const size_t alignment =
       get_and_verify_alignment(input, "two_shot_all_reduce");
 
+  if (output.has_value()) {
+    const size_t output_alignment =
+        get_and_verify_alignment(*output, "two_shot_all_reduce");
+    TORCH_CHECK(
+        alignment <= output_alignment,
+        "two_shot_all_reduce: output alignment must be equal to or larger than input.");
+  }
+
   int num_blocks = 0, num_threads = 0;
   init_elementwise_launch_config(
       input.numel(),
@@ -557,30 +640,73 @@ at::Tensor two_shot_all_reduce_(
       num_blocks,
       num_threads);
 
-  AT_DISPATCH_FLOAT_AND_BFLOAT16(
-      input.scalar_type(), "two_shot_all_reduce", [&]() {
-        DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
-          DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() {
-            two_shot_all_reduce_kernel<scalar_t, k_alignment, k_world_size>
-                <<<num_blocks,
-                   num_threads,
-                   0,
-                   at::cuda::getCurrentCUDAStream()>>>(
-                    reinterpret_cast<scalar_t**>(
-                        symm_mem->get_buffer_ptrs_dev()),
-                    input.storage_offset(),
-                    input.numel(),
-                    reinterpret_cast<uint32_t**>(
-                        symm_mem->get_signal_pad_ptrs_dev()),
-                    symm_mem->get_rank(),
-                    symm_mem->get_world_size());
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
+  if (!output.has_value()) {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel_inplace<
+                  scalar_t,
+                  k_alignment,
+                  k_world_size>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
           });
         });
-      });
-  return input;
+    return input;
+  } else {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel<scalar_t, k_alignment, k_world_size>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      output->data_ptr<scalar_t>(),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+          });
+        });
+    return *output;
+  }
+}
+
+at::Tensor two_shot_all_reduce_(
+    at::Tensor input,
+    std::string reduce_op,
+    std::string group_name) {
+  return two_shot_all_reduce_impl(input, std::nullopt, reduce_op, group_name);
 }
 
+at::Tensor two_shot_all_reduce_out(
+    at::Tensor input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor output) {
+  return two_shot_all_reduce_impl(input, output, reduce_op, group_name);
+}
 } // namespace
 #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
 
@@ -713,6 +839,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
   m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
+  m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
+
   m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
 #endif
   m.impl("stream_write_value32_", ::stream_write_value32_);
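For completeness, a hypothetical caller-side sketch (not part of the patch) of invoking the newly registered op through the dispatcher; the schema name symm_mem::two_shot_all_reduce_out and the argument order and types are assumptions that mirror the C++ signature above:

#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <string>

// Assumed schema: two_shot_all_reduce_out(Tensor input, str reduce_op,
//                                         str group_name, Tensor out) -> Tensor
at::Tensor call_two_shot_all_reduce_out(
    const at::Tensor& input,
    const at::Tensor& output,
    const std::string& group_name) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("symm_mem::two_shot_all_reduce_out", "")
          .typed<at::Tensor(at::Tensor, std::string, std::string, at::Tensor)>();
  // Only "sum" is shown here; the reduce_op string is forwarded unchanged.
  return op.call(input, "sum", group_name, output);
}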