
Commit 24ac873

Yifu Wang authored and pytorchmergebot committed
[AsyncMM] re-enable and prepare for cutlass 3.5.1 update (pytorch#144011)
Pull Request resolved: pytorch#144011
Approved by: https://github.com/Skylion007, https://github.com/drisspg
1 parent 73a6a40 commit 24ac873

File tree

torch/csrc/distributed/c10d/cuda/AsyncMM.cu
torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh

2 files changed: +33 -33 lines changed

torch/csrc/distributed/c10d/cuda/AsyncMM.cu

Lines changed: 19 additions & 6 deletions

@@ -5,13 +5,15 @@
 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
 #include <c10/cuda/CUDAGuard.h>
 
-#if false && !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
+#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
     CUDA_VERSION >= 12000
 #define BUILD_ASYNC_MM_KERNEL
 #endif
 
 #if defined(BUILD_ASYNC_MM_KERNEL)
 
+// TODO(yifu): remove this once cutlass 3.5.1 upgrade is completed
+#if CUTLASS_VERSION != 351
 // We are going to override the cuTensorMapEncodeTiled driver api with our lazy
 // loader
 static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
@@ -56,7 +58,19 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
 #include <cute/tensor.hpp>
 #undef cuTensorMapEncodeTiled
 // Set everything back to normal
+// clang-format on
+#else
+#include <cutlass/core_io.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm.h>
+#include <cutlass/half.h>
+#include <cutlass/numeric_types.h>
+#include <cutlass/trace.h>
+#include <cutlass/util/host_tensor.h>
+#include <cute/tensor.hpp>
+#endif
 
+#include <cutlass/version.h>
 #include <cutlass/gemm/collective/collective_builder.hpp>
 #include <cutlass/gemm/device/gemm_universal_adapter.h>
 #include <cutlass/epilogue/collective/collective_builder.hpp>
@@ -65,7 +79,6 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
 #include <cutlass/gemm/dispatch_policy.hpp>
 #include <cutlass/gemm/kernel/gemm_universal.hpp>
 #include <cutlass/util/packed_stride.hpp>
-// clang-format on
 
 #include <torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh>
 
@@ -107,7 +120,7 @@ at::Tensor async_input_mm_impl(
       cutlass::epilogue::collective::EpilogueTileAuto,
       ElementAccumulator,
       ElementAccumulator,
-      void,
+      ElementC,
       LayoutC,
       AlignmentC,
       ElementC,
@@ -133,7 +146,7 @@ at::Tensor async_input_mm_impl(
       KernelSchedule>::CollectiveOp;
 
   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-      Shape<int, int, int, int>,
+      Shape<int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
@@ -171,15 +184,15 @@ at::Tensor async_input_mm_impl(
 
  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
-      {M, N, K, 1},
+      {M, N, K},
      {
          reinterpret_cast<ElementA*>(a.data_ptr<at::BFloat16>()),
          stride_A,
          reinterpret_cast<ElementB*>(b.data_ptr<at::BFloat16>()),
          stride_B,
      },
      {{1, 1},
-       nullptr,
+       reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
       stride_C,
       reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
       stride_C},
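
The AsyncMM.cu change re-enables the kernel (dropping the `false &&` from the build guard) and gates the cuTensorMapEncodeTiled lazy-loader override behind a CUTLASS version check, so the file builds both against the current pin and after the 3.5.1 upgrade; the version check in the diff implies that 3.5.1 corresponds to `CUTLASS_VERSION == 351`. Below is a minimal, self-contained sketch of that gating pattern. It is illustrative only: the fallback `#define` exists solely so the sketch compiles without CUTLASS installed, and the printed messages are not from this PR.

```cpp
// Sketch: gate a code path on the CUTLASS version, mirroring the diff above.
// <cutlass/version.h> provides CUTLASS_VERSION; the stand-in definition below
// is a hypothetical placeholder so this example compiles without CUTLASS.
#include <cstdio>

#ifndef CUTLASS_VERSION
#define CUTLASS_VERSION 351 // stand-in for the value from <cutlass/version.h>
#endif

int main() {
#if CUTLASS_VERSION != 351
  // Pre-3.5.1 path: keep the cuTensorMapEncodeTiled lazy-loader override in
  // place before pulling in the CUTLASS/CuTe headers.
  std::printf("pre-3.5.1 path\n");
#else
  // 3.5.1 path: include the CUTLASS headers directly; no driver-API override.
  std::printf("cutlass 3.5.1 path\n");
#endif
  return 0;
}
```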

torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh

Lines changed: 14 additions & 27 deletions

@@ -263,7 +263,9 @@ public:
       cta_m, cta_n
     );
   }
-  // Kernel helper function to get next work ID
+
+  // TODO(yifu): remove this once cutlass 3.5.1 upgrade is completed
+#if CUTLASS_VERSION != 351
   template <class WorkIdPipeline, class WorkIdPipelineState>
   CUTLASS_DEVICE
   auto
@@ -278,19 +280,18 @@ public:
     // Return true to indicate that the WorkID pipeline state should be advanced
     return cute::make_tuple(new_work_tile_info, true);
   }
-
+#else
   CUTLASS_DEVICE
-  static auto
-  work_tile_to_cta_coord(WorkTileInfo work_tile_info) {
-    // Get every cta coord in three dimensions of the cluster
-    auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster();
-    return make_coord(
-      work_tile_info.M_idx + static_cast<int32_t>(cta_m_in_cluster),
-      work_tile_info.N_idx + static_cast<int32_t>(cta_n_in_cluster),
-      _,
-      work_tile_info.L_idx + static_cast<int32_t>(cta_l_in_cluster)
-    );
+  auto
+  fetch_next_work(WorkTileInfo work_tile_info) {
+    if (continue_current_work(work_tile_info)) {
+      return work_tile_info;
+    }
+
+    advance_to_next_work();
+    return get_current_work();
   }
+#endif
 
   // Given the inputs, computes the physical grid we should launch.
   template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
@@ -347,20 +348,6 @@ public:
     );
   }
 
-  // Convert CTA-level work tile info to cluster-level tile coord
-  CUTLASS_DEVICE
-  cute::Coord<int,int,int,int>
-  tile_info_to_coord_mnkl(WorkTileInfo work_tile_info) const {
-    // TileScheduler works at CTA-level, kernel works at cluster-level
-    int m_coord = idx2crd(work_tile_info.M_idx / params.cluster_shape_m_,
-                          params.problem_tiles_m_);
-    int n_coord = idx2crd(work_tile_info.N_idx / params.cluster_shape_n_,
-                          params.problem_tiles_n_);
-    int l_coord = idx2crd(work_tile_info.L_idx,
-                          params.problem_tiles_l_);
-    return make_coord(m_coord, n_coord, _, l_coord);
-  }
-
   // Returns whether the block assigned this work should compute the epilogue for the corresponding
   // output tile. For the basic tile scheduler, this is always true.
   CUTLASS_HOST_DEVICE
@@ -471,7 +458,7 @@ public:
   template <class ProblemShape, class ElementAccumulator>
   static cutlass::Status
   initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&,
-    uint32_t, const uint32_t = 1) {
+    uint32_t, const uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) {
     return Status::kSuccess;
   }
 public:
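
In the scheduler header, the pre-3.5.1 `fetch_next_work` (the pipelined, `WorkIdPipeline`-based overload) is kept behind `#if CUTLASS_VERSION != 351`, while the `#else` branch adds a simpler `fetch_next_work(WorkTileInfo)` built on `continue_current_work` / `advance_to_next_work` / `get_current_work`; the `work_tile_to_cta_coord` and `tile_info_to_coord_mnkl` helpers are removed, and `initialize_workspace` gains a defaulted `CudaHostAdapter*` argument, presumably to match the updated scheduler interface. A rough stand-alone mock of the new `fetch_next_work` control flow is sketched below; `MockScheduler` and its members are hypothetical stand-ins, not the real CUTLASS scheduler API.

```cpp
// Sketch: the fetch_next_work control flow added for the CUTLASS 3.5.1 path.
// Everything here is a mock; only the shape of fetch_next_work mirrors the diff.
#include <cstdio>

struct WorkTileInfo {
  int M_idx = 0;
  int N_idx = 0;
  bool valid = true;
};

struct MockScheduler {
  int next_tile_ = 0;

  bool continue_current_work(const WorkTileInfo&) const {
    // The real scheduler returns true while the current tile still has work
    // left; this mock always moves on to the next tile.
    return false;
  }

  void advance_to_next_work() { ++next_tile_; }

  WorkTileInfo get_current_work() const {
    // Produce four valid tiles, then an invalid one to stop the loop.
    return WorkTileInfo{next_tile_, 0, next_tile_ < 4};
  }

  // Same control flow as the 3.5.1 branch in the diff: keep the current tile
  // if it is unfinished, otherwise advance and return the scheduler's next tile.
  WorkTileInfo fetch_next_work(WorkTileInfo work_tile_info) {
    if (continue_current_work(work_tile_info)) {
      return work_tile_info;
    }
    advance_to_next_work();
    return get_current_work();
  }
};

int main() {
  MockScheduler sched;
  WorkTileInfo tile = sched.get_current_work();
  while (tile.valid) {
    std::printf("processing tile (M=%d, N=%d)\n", tile.M_idx, tile.N_idx);
    tile = sched.fetch_next_work(tile);
  }
  return 0;
}
```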
