@@ -397,27 +397,34 @@ at::Tensor multimem_all_gather_out(
 // One-shot all-reduce is register-intensive because it stages values loaded
 // from peers in registers before performing reduction. Setting the thread
 // count to 512 to prevent/alleviate register spill.
-constexpr size_t one_shot_all_reduce_max_num_blocks = 8;
+constexpr size_t one_shot_all_reduce_max_num_blocks = 24;
 constexpr size_t one_shot_all_reduce_max_num_threads = 512;
 
 template <typename T, int alignment, int k_world_size>
 static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
     void one_shot_all_reduce_kernel(
         T** input_ptrs,
         T* output_ptr,
+        T* input_ptr,
         size_t input_offset,
         size_t numel,
         uint32_t** signal_pads,
         size_t rank,
         size_t world_size) {
   static_assert(alignment % sizeof(T) == 0);
   constexpr size_t numel_per_thread = alignment / sizeof(T);
-
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
-  __syncthreads();
-
+  // Stage the local input into this rank's symmetric buffer before the barrier.
   auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
   auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  if (input_ptr) {
+    for (size_t i = offset; i < numel; i += stride) {
+      Vec<alignment> vec_st = ld_vec<alignment>(input_ptr + i);
+      st_vec<alignment>(input_ptrs[rank] + input_offset + i, vec_st);
+    }
+  }
+  // TODO: use a single-block sync for the no-copy case.
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
 
   for (size_t i = offset; i < numel; i += stride) {
     auto vec = load_and_reduce<T, alignment, k_world_size>(
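The staging copy and the barrier upgrade in this hunk go together: the new stores into this rank's slot of the symmetric buffer must be visible to peers before they load from it, so the pre-reduction barrier moves from relaxed to acquire/release ordering. Below is a minimal, self-contained sketch of the grid-stride, vectorized copy pattern the kernel uses; the kernel name and the `float4` element type are illustrative assumptions, standing in for the file's actual `Vec<alignment>`/`ld_vec`/`st_vec` machinery.

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Grid-stride, 16-byte-vectorized copy: each thread starts at its global
// index and advances by the total thread count in the grid, so any grid
// size covers the whole buffer. num_vecs = numel / (16 / sizeof(T)).
__global__ void stage_copy(const float4* src, float4* dst, size_t num_vecs) {
  size_t offset = blockDim.x * blockIdx.x + threadIdx.x;
  size_t stride = blockDim.x * static_cast<size_t>(gridDim.x);
  for (size_t i = offset; i < num_vecs; i += stride) {
    dst[i] = src[i];  // one 16-byte load + one 16-byte store per iteration
  }
}
```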
@@ -426,11 +433,12 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
   }
 
   __syncthreads();
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
 }
 
-at::Tensor one_shot_all_reduce_out(
+at::Tensor one_shot_all_reduce_out_impl(
     const at::Tensor& input,
+    const std::optional<at::Tensor>& local_input,
     std::string reduce_op,
     std::string group_name,
     at::Tensor out) {
@@ -440,18 +448,35 @@ at::Tensor one_shot_all_reduce_out(
       out.is_contiguous(), "one_shot_all_reduce: output must be contiguous.");
   TORCH_CHECK(
       out.sizes() == input.sizes(),
-      "one_shot_all_reduce: input/output size mismatch.");
+      "one_shot_all_reduce: input/output size mismatch, input.sizes(): ",
+      input.sizes(),
+      ", output.sizes(): ",
+      out.sizes());
   TORCH_CHECK(
       reduce_op == "sum",
       "one_shot_all_reduce: only sum is supported for now.");
-
+  if (local_input.has_value()) {
+    TORCH_CHECK(
+        local_input->is_contiguous(),
+        "one_shot_all_reduce: local input must be contiguous.");
+    TORCH_CHECK(
+        local_input->numel() <= input.numel(),
+        "one_shot_all_reduce: local input must not exceed symm buffer size.");
+  }
   auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
   TORCH_CHECK(
       symm_mem != nullptr,
       "one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
 
   const size_t alignment =
       get_and_verify_alignment(input, "one_shot_all_reduce");
+  if (local_input.has_value()) {
+    const size_t local_alignment =
+        get_and_verify_alignment(*local_input, "one_shot_all_reduce");
+    TORCH_CHECK(
+        alignment == local_alignment,
+        "one_shot_all_reduce: local input and symm buffer must have the same alignment.");
+  }
 
   int num_blocks = 0, num_threads = 0;
   init_elementwise_launch_config(
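For context, the alignment check gates which vectorized kernel instantiation is legal; requiring the local input to match the symm buffer's alignment keeps the staging copy and the reduction on the same vector width. A hypothetical sketch of the kind of computation a probe like `get_and_verify_alignment` performs; the real helper's signature, limits, and error handling may differ:

```cpp
#include <cstddef>
#include <cstdint>

// Largest power-of-two vector width (capped at 16 bytes, i.e. a uint4-sized
// load/store) that divides both the base address and the byte count, so
// every vectorized access is aligned and stays in bounds.
static size_t probe_alignment(const void* ptr, size_t nbytes) {
  size_t align = 16;
  while (align > 1 &&
         ((reinterpret_cast<uintptr_t>(ptr) % align != 0) ||
          (nbytes % align != 0))) {
    align /= 2;
  }
  return align;
}
```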
@@ -476,6 +501,8 @@ at::Tensor one_shot_all_reduce_out(
             reinterpret_cast<scalar_t**>(
                 symm_mem->get_buffer_ptrs_dev()),
             out.data_ptr<scalar_t>(),
+            local_input.has_value() ? local_input->data_ptr<scalar_t>()
+                                    : nullptr,
             input.storage_offset(),
             input.numel(),
             reinterpret_cast<uint32_t**>(
@@ -489,12 +516,42 @@ at::Tensor one_shot_all_reduce_out(
   return out;
 }
 
+at::Tensor one_shot_all_reduce_out(
+    const at::Tensor& input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, std::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy_out(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
+}
+
 at::Tensor one_shot_all_reduce(
     const at::Tensor& input,
     std::string reduce_op,
     std::string group_name) {
   auto out = at::empty_like(input);
-  return one_shot_all_reduce_out(input, reduce_op, group_name, out);
+  return one_shot_all_reduce_out_impl(
+      input, std::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  auto out = at::empty_like(local_input);
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
 }
 
 constexpr size_t two_shot_all_reduce_max_num_blocks = 24;
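A hedged usage sketch of the new copy variant: `symm_buf` is assumed to be a tensor allocated with `empty_strided_p2p()` and already rendezvous'd under `group_name`, while `local` is an ordinary contiguous CUDA tensor no larger than the buffer. The kernel stages `local` into this rank's slot itself, so the caller no longer needs a separate `copy_` before reducing; the wrapper function and its name below are illustrative, not part of the patch.

```cpp
#include <ATen/ATen.h>

// Hypothetical caller: reduce a plain CUDA tensor through a pre-allocated
// symmetric-memory staging buffer in a single kernel launch.
at::Tensor all_reduce_from_local(
    const at::Tensor& symm_buf,  // allocated with empty_strided_p2p()
    const at::Tensor& local,     // contiguous, local.numel() <= symm_buf.numel()
    const std::string& group_name) {
  // Output is shaped like the local input (see one_shot_all_reduce_copy above).
  return one_shot_all_reduce_copy(symm_buf, local, "sum", group_name);
}
```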
@@ -838,6 +895,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("multimem_all_gather_out", ::multimem_all_gather_out);
   m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
   m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
+  m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
+  m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
   m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);