Commit 778d953

Revert "[AsyncMM] re-enable and prepare for cutlass 3.5.1 update (pytorch#144011)"
This reverts commit 24ac873. Reverted pytorch#144011 on behalf of https://github.com/malfet due to: "Not sure what is going on, but lots of builds are failing" ([comment](pytorch#144011 (comment)))
1 parent f4e9aeb · commit 778d953

File tree

2 files changed: +33 −33 lines

torch/csrc/distributed/c10d/cuda/AsyncMM.cu
torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh

torch/csrc/distributed/c10d/cuda/AsyncMM.cu

Lines changed: 6 additions & 19 deletions
@@ -5,15 +5,13 @@
 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
 #include <c10/cuda/CUDAGuard.h>
 
-#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
+#if false && !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
     CUDA_VERSION >= 12000
 #define BUILD_ASYNC_MM_KERNEL
 #endif
 
 #if defined(BUILD_ASYNC_MM_KERNEL)
 
-// TODO(yifu): remove this once cutlass 3.5.1 upgrade is completed
-#if CUTLASS_VERSION != 351
 // We are going to override the cuTensorMapEncodeTiled driver api with our lazy
 // loader
 static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
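
Note: the `#if false &&` edit above is a compile-time kill switch: the constant-false clause makes the whole condition unsatisfiable, so `BUILD_ASYNC_MM_KERNEL` is never defined and everything guarded by it drops out of every build while the code stays in-tree. A minimal standalone sketch of the pattern; the fallback body is hypothetical and not taken from this file:

    // Kill-switch sketch: the leading `false &&` can never be satisfied,
    // so only the fallback branch compiles, on every platform.
    #include <stdexcept>

    #if false && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
    #define BUILD_ASYNC_MM_KERNEL
    #endif

    #if defined(BUILD_ASYNC_MM_KERNEL)
    int async_mm() { return 0; }  // real kernel path would live here
    #else
    // Hypothetical fallback: fail loudly at runtime instead of at compile time.
    int async_mm() { throw std::runtime_error("AsyncMM kernel was not built"); }
    #endif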
@@ -58,19 +56,7 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
 #include <cute/tensor.hpp>
 #undef cuTensorMapEncodeTiled
 // Set everything back to normal
-// clang-format on
-#else
-#include <cutlass/core_io.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm.h>
-#include <cutlass/half.h>
-#include <cutlass/numeric_types.h>
-#include <cutlass/trace.h>
-#include <cutlass/util/host_tensor.h>
-#include <cute/tensor.hpp>
-#endif
 
-#include <cutlass/version.h>
 #include <cutlass/gemm/collective/collective_builder.hpp>
 #include <cutlass/gemm/device/gemm_universal_adapter.h>
 #include <cutlass/epilogue/collective/collective_builder.hpp>

@@ -79,6 +65,7 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
 #include <cutlass/gemm/dispatch_policy.hpp>
 #include <cutlass/gemm/kernel/gemm_universal.hpp>
 #include <cutlass/util/packed_stride.hpp>
+// clang-format on
 
 #include <torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh>

@@ -120,7 +107,7 @@ at::Tensor async_input_mm_impl(
       cutlass::epilogue::collective::EpilogueTileAuto,
       ElementAccumulator,
       ElementAccumulator,
-      ElementC,
+      void,
       LayoutC,
       AlignmentC,
       ElementC,
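
Note: in the CUTLASS 3.x epilogue `CollectiveBuilder`, the slot changed here is the element type of the C source operand, and `void` tells the builder there is no C to read, so the epilogue writes D from the accumulator alone. A sketch of where that slot sits in the builder; `TileShape`, `ClusterShape`, and the `EpilogueScheduleAuto` schedule are assumed stand-ins following common CUTLASS usage, not names read from this file:

    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        TileShape, ClusterShape,
        cutlass::epilogue::collective::EpilogueTileAuto,
        ElementAccumulator,   // compute type for epilogue scaling
        ElementAccumulator,   // accumulator element type
        void,                 // C source element type: void = no C operand is loaded
        LayoutC, AlignmentC,  // layout/alignment slots for C are still required
        ElementC,             // D (output) element type
        LayoutC, AlignmentC,
        cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;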
@@ -146,7 +133,7 @@ at::Tensor async_input_mm_impl(
       KernelSchedule>::CollectiveOp;
 
   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-      Shape<int, int, int>,
+      Shape<int, int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
@@ -184,15 +171,15 @@ at::Tensor async_input_mm_impl(
 
   typename Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
-      {M, N, K},
+      {M, N, K, 1},
       {
           reinterpret_cast<ElementA*>(a.data_ptr<at::BFloat16>()),
           stride_A,
           reinterpret_cast<ElementB*>(b.data_ptr<at::BFloat16>()),
           stride_B,
       },
       {{1, 1},
-       reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
+       nullptr,
        stride_C,
        reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
        stride_C},
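
Note: the `{M, N, K, 1}` literal matches the rank-4 `Shape<int, int, int, int>` problem shape restored above; the fourth mode is the batch count L, fixed at 1. And with the C source element type reverted to `void`, the C pointer slot may be left `nullptr`; `out` feeds only the D slot. A field-by-field sketch; the comments and the `mainloop_args` shorthand are ours, following common CUTLASS 3.x conventions:

    typename Gemm::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm,
        {M, N, K, 1},   // rank-4 problem shape (M, N, K, L), L = batch count = 1
        mainloop_args,  // hypothetical shorthand for {ptr_A, stride_A, ptr_B, stride_B}
        {{1, 1},        // epilogue scale factors (alpha, beta)
         nullptr,       // C source pointer; legal to leave null since ElementC is void
         stride_C,      // C stride, still supplied to satisfy the argument layout
         reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),  // D output pointer
         stride_C}};    // D reuses C's stride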

torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh

Lines changed: 27 additions & 14 deletions
@@ -263,9 +263,7 @@ public:
       cta_m, cta_n
     );
   }
-
-  // TODO(yifu): remove this once cutlass 3.5.1 upgrade is completed
-  #if CUTLASS_VERSION != 351
+  // Kernel helper function to get next work ID
   template <class WorkIdPipeline, class WorkIdPipelineState>
   CUTLASS_DEVICE
   auto
@@ -280,18 +278,19 @@ public:
     // Return true to indicate that the WorkID pipeline state should be advanced
     return cute::make_tuple(new_work_tile_info, true);
   }
-  #else
-  CUTLASS_DEVICE
-  auto
-  fetch_next_work(WorkTileInfo work_tile_info) {
-    if (continue_current_work(work_tile_info)) {
-      return work_tile_info;
-    }
 
-    advance_to_next_work();
-    return get_current_work();
+  CUTLASS_DEVICE
+  static auto
+  work_tile_to_cta_coord(WorkTileInfo work_tile_info) {
+    // Get every cta coord in three dimensions of the cluster
+    auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster();
+    return make_coord(
+      work_tile_info.M_idx + static_cast<int32_t>(cta_m_in_cluster),
+      work_tile_info.N_idx + static_cast<int32_t>(cta_n_in_cluster),
+      _,
+      work_tile_info.L_idx + static_cast<int32_t>(cta_l_in_cluster)
+    );
   }
-  #endif
 
   // Given the inputs, computes the physical grid we should launch.
   template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
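
Note: the restored `work_tile_to_cta_coord` fans one cluster-wide work assignment out to per-CTA tile coordinates: each CTA adds its own position inside the cluster, from `cute::block_id_in_cluster()`, to the scheduler's base (M, N, L) tile index. A plain-integer sketch of that arithmetic; all values are hypothetical:

    struct WorkTile { int M_idx, N_idx, L_idx; };

    int cta_coord_demo() {
      WorkTile tile{4, 6, 0};    // base tile index handed out by the scheduler
      int cta_m_in_cluster = 1;  // this CTA sits at (m=1, n=0) inside a 2x1 cluster
      int cta_n_in_cluster = 0;
      int m = tile.M_idx + cta_m_in_cluster;  // 5: the M tile this CTA computes
      int n = tile.N_idx + cta_n_in_cluster;  // 6: the N tile this CTA computes
      // K stays a wildcard (_) in the real code; the mainloop sweeps all of K.
      return m + n;
    }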
@@ -348,6 +347,20 @@ public:
     );
   }
 
+  // Convert CTA-level work tile info to cluster-level tile coord
+  CUTLASS_DEVICE
+  cute::Coord<int,int,int,int>
+  tile_info_to_coord_mnkl(WorkTileInfo work_tile_info) const {
+    // TileScheduler works at CTA-level, kernel works at cluster-level
+    int m_coord = idx2crd(work_tile_info.M_idx / params.cluster_shape_m_,
+                          params.problem_tiles_m_);
+    int n_coord = idx2crd(work_tile_info.N_idx / params.cluster_shape_n_,
+                          params.problem_tiles_n_);
+    int l_coord = idx2crd(work_tile_info.L_idx,
+                          params.problem_tiles_l_);
+    return make_coord(m_coord, n_coord, _, l_coord);
+  }
+
   // Returns whether the block assigned this work should compute the epilogue for the corresponding
   // output tile. For the basic tile scheduler, this is always true.
   CUTLASS_HOST_DEVICE
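
Note: the division in the restored `tile_info_to_coord_mnkl` is the whole CTA-to-cluster conversion: `M_idx`/`N_idx` count CTA-sized tiles, the kernel wants cluster-sized coordinates, so each index is divided by the cluster extent along its mode (L is not cluster-tiled and passes through). A worked example with hypothetical sizes:

    int cluster_coord_demo() {
      const int cluster_shape_m = 2;          // CTAs per cluster along M
      const int M_idx = 5;                    // CTA-level tile index from the scheduler
      int m_coord = M_idx / cluster_shape_m;  // 2: cluster-level M coordinate
      // CTA tiles 4 and 5 sit in the same cluster, so both map to m_coord == 2.
      return m_coord;
    }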
@@ -458,7 +471,7 @@ public:
   template <class ProblemShape, class ElementAccumulator>
   static cutlass::Status
   initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&,
-    uint32_t, const uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) {
+    uint32_t, const uint32_t = 1) {
     return Status::kSuccess;
   }
 public:
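
Note: the trailing `CudaHostAdapter*` parameter dropped here belongs to the newer CUTLASS workspace-hook interface; the pre-3.5.1 interface this revert restores does not take it, and a custom scheduler must match whichever shape the kernel layer calls even when the hook is a no-op. A paraphrased sketch of the restored no-op hook; the signature shape follows the hunk above, not CUTLASS headers:

    template <class ProblemShape, class ElementAccumulator>
    static cutlass::Status
    initialize_workspace(Arguments const&, void* /*workspace*/, cudaStream_t,
                         ProblemShape, KernelHardwareInfo const&,
                         uint32_t, const uint32_t = 1
                         /* newer CUTLASS adds: CudaHostAdapter* cuda_adapter = nullptr */) {
      return cutlass::Status::kSuccess;  // this scheduler needs no device workspace
    }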
