Commit 8571ac4

[Kernel] Update CUTLASS to 3.5.1 (#7085)
1 parent 997cf78 commit 8571ac4

File tree (4 files changed: +129 −107 lines)

  CMakeLists.txt
  csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
  csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
  csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu

CMakeLists.txt

Lines changed: 3 additions & 3 deletions

@@ -193,8 +193,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
     cutlass
     GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-    # CUTLASS 3.5.0
-    GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
+    # CUTLASS 3.5.1
+    GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
     # Shallow clone with depth 1
     GIT_SHALLOW TRUE
     GIT_PROGRESS TRUE

@@ -237,7 +237,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)
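
The tools/util include path is dropped from INCLUDE_DIRECTORIES because the extensions no longer use any cutlass/util headers (see the device_memory.h removals below). If a build picks up a stale CUTLASS checkout, the epilogue changes below fail in confusing ways, so a compile-time version guard can help. A minimal sketch, assuming cutlass/version.h and its CUTLASS_MAJOR/CUTLASS_MINOR/CUTLASS_PATCH macros (present in recent CUTLASS releases); this guard is illustrative and not part of the commit:

// cutlass_version_guard.hpp -- hypothetical helper, not part of this commit.
// Assumes cutlass/version.h defines CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH.
#pragma once

#include "cutlass/version.h"

// Fail the build early if the fetched CUTLASS is older than 3.5.1,
// e.g. because FetchContent reused an old GIT_TAG checkout.
static_assert(
    CUTLASS_MAJOR > 3 ||
        (CUTLASS_MAJOR == 3 &&
         (CUTLASS_MINOR > 5 || (CUTLASS_MINOR == 5 && CUTLASS_PATCH >= 1))),
    "These kernels require CUTLASS >= 3.5.1; check the FetchContent GIT_TAG.");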

csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp

Lines changed: 111 additions & 81 deletions

@@ -64,23 +64,19 @@ using namespace detail;

 // Row vector broadcast
 template<
-  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
-  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
   int Stages,
   class CtaTileShapeMNK,
   class Element,
   class StrideMNL = Stride<_0,_1,_0>,
   int Alignment = 128 / sizeof_bits_v<Element>
 >
 struct Sm90RowOrScalarBroadcast {
-  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
-  static_assert(
-    (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
+  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

-  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
-  struct SharedStorage {
-    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
+  struct SharedStorage {
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
   };

 // This struct has been modified to have a bool indicating that ptr_row is a

@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
     return args;
   }

+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {

@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {

   CUTLASS_HOST_DEVICE
   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
-    : params(params),
-      smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
+    : params(params)
+    , smem(const_cast<Element*>(shared_storage.smem.data())) { }

   Params params;
-  Element* smem_row;
+  Element *smem = nullptr;

   CUTLASS_DEVICE bool
   is_producer_load_needed() const {
-    return true;
+    return false;
   }

   CUTLASS_DEVICE bool

@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
     return (!params.row_broadcast && *(params.ptr_row) == Element(0));
   }

-  template <int EpiTiles, class GTensor, class STensor>
-  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
-    CUTLASS_DEVICE
-    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
-      : gRow(cute::forward<GTensor>(gRow)),
-        sRow(cute::forward<STensor>(sRow)),
-        params(params) {}
-
-    GTensor gRow; // (CTA_M,CTA_N)
-    STensor sRow; // (CTA_M,CTA_N,PIPE)
-    Params const& params;
-
-    CUTLASS_DEVICE void
-    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
-      if (!params.row_broadcast) {
-        return;
-      }
-
-      if (issue_tma_load) {
-        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
-        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
-        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
-        // Issue the TMA bulk copy
-        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
-        // Filter so we don't issue redundant copies over stride-0 modes
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
-      }
-    }
-  };
-
   template <class... Args>
   CUTLASS_DEVICE auto
   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
-
-    auto [M, N, K, L] = args.problem_shape_mnkl;
-    auto [m, n, k, l] = args.tile_coord_mnkl;
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
-    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
-        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
-      cute::move(gRow), cute::move(sRow), params);
+    return EmptyProducerLoadCallbacks{};
   }

-  template <int EpiTiles, class RTensor, class STensor>
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
     CUTLASS_DEVICE
-    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
-      : tCrRow(cute::forward<RTensor>(tCrRow)),
-        tCsRow(cute::forward<STensor>(tCsRow)),
-        params(params) {}
-
-    RTensor tCrRow; // (CPY,CPY_M,CPY_N)
-    STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
+    ConsumerStoreCallbacks(
+      GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
+      GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
+      SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+      CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
+      : tGS_gRow(tGS_gRow_)
+      , tGS_sRow(tGS_sRow_)
+      , tGS_cRow(tGS_cRow_)
+      , tiled_G2S(tiled_g2s_)
+      , tSR_sRow(tSR_sRow_)
+      , tSR_rRow(tSR_rRow_)
+      , tCcRow(tCcRow_)
+      , residue_tCcRow(residue_tCcRow_)
+      , params(params_) {}
+
+    GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
+    GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
+    GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
+    Tiled_G2S tiled_G2S;
+
+    SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+    CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    ThrResidue residue_tCcRow; // (m, n)
+    ThrNum thr_num;
     Params const& params;

     CUTLASS_DEVICE void
-    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
+    begin() {
       if (!params.row_broadcast) {
-        fill(tCrRow, *(params.ptr_row));
+        fill(tSR_rRow, *(params.ptr_row));
         return;
       }

+      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+          continue; // OOB of SMEM
+        }
+        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+          tGS_sRow_flt(i) = tGS_gRow_flt(i);
+        }
+        else {
+          tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
+        }
+      }
+      synchronize();
+    }
+
+    CUTLASS_DEVICE void
+    begin_loop(int epi_m, int epi_n) {
       if (epi_m == 0) { // Assumes M-major subtile loop
-        // Filter so we don't issue redundant copies over stride-0 modes
-        // (only works if 0-strides are in same location, which is by construction)
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
+        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
+        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+        copy(tSR_sRow_flt, tSR_rRow_flt);
       }
     }

@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {

       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < FragmentSize; ++i) {
-        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
+        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
       }

       return frg_row;

@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
   >
   CUTLASS_DEVICE auto
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+    using ThreadCount = decltype(size(args.tiled_copy));

-    Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
-        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
-        sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
-      cute::move(tCrRow), cute::move(tCsRow), params);
+    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
+    Tensor sRow = make_tensor(make_smem_ptr(smem),
+        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
+    //// G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                     Layout< Shape<_1, ThreadCount>,
+                                             Stride<_0, _1>>{},
+                                     Layout<_1>{});
+    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+    Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+    //// G2S: Coord
+    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+    Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+    //// S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
+
+    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
+      tGS_gRow,
+      tGS_sRow,
+      tGS_cRow, tiled_g2s,
+      tSR_sRow,
+      tSR_rRow,
+      args.tCcD,
+      args.residue_cD,
+      ThreadCount{},
+      params);
   }
 };

@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
     return args;
   }

+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
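
Summary of the change above: Sm90RowOrScalarBroadcast (used for the per-channel ScaleB and bias rows) now requires Stages == 0, drops the TMA producer-load path entirely (is_producer_load_needed() returns false), and instead stages the row vector through shared memory with a plain gmem-to-smem copy in the consumer begin() callback. Valid columns are copied, out-of-bounds columns are zero-filled so the later smem-to-register loads in begin_loop() need no predicates, and a named barrier orders the copy against those loads. The sketch below is a CPU-side analogue of that staging step, illustrative only and using nothing beyond the standard library:

// Conceptual CPU-side analogue of the new begin() staging step above; names and
// the float element type are illustrative, not part of the commit.
#include <algorithm>
#include <cstddef>
#include <vector>

// Copy the valid prefix of a broadcast row into a fixed-size tile buffer and
// zero-fill the out-of-bounds tail, so later reads of the tile need no predicates.
std::vector<float> stage_row_tile(const float* row, std::size_t n_valid,
                                  std::size_t tile_n) {
  std::vector<float> tile(tile_n, 0.0f);  // zero-filled "shared memory" tile
  std::copy(row, row + std::min(n_valid, tile_n), tile.begin());
  return tile;                            // safe to read [0, tile_n) unconditionally
}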

csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh

Lines changed: 4 additions & 4 deletions

@@ -10,8 +10,6 @@
 #include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"

-#include "cutlass/util/device_memory.h"
-
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm_coord.h"
 #include "cutlass/arch/mma_sm75.h"

@@ -301,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   // Launch the CUTLASS GEMM kernel.
   typename Gemm::Op gemm_op;
   size_t workspace_size = gemm_op.get_workspace_size(args);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);

   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

   CUTLASS_CHECK(gemm_op.can_implement(args));
-  cutlass::Status status = gemm_op(args, workspace.get(), stream);
+  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }
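
The workspace for both the c2x and c3x callers is now a uint8 torch tensor instead of a cutlass::device_memory::allocation, which removes the last use of cutlass/util (hence the dropped include here and the trimmed include path in CMakeLists.txt) and routes the scratch memory through PyTorch's caching allocator on the input's device. A minimal standalone sketch of the same pattern; the header choice and helper name are assumptions, not code from the commit:

// workspace.hpp -- hypothetical helper mirroring the pattern above.
#include <cstddef>
#include <torch/torch.h>

// Back a CUTLASS GEMM workspace with a uint8 torch tensor on the target device,
// so the allocation comes from PyTorch's caching allocator.
inline torch::Tensor allocate_cutlass_workspace(std::size_t workspace_size,
                                                torch::Device const& device) {
  auto const options = torch::TensorOptions().dtype(torch::kUInt8).device(device);
  // Uninitialized contents are fine; CUTLASS writes the workspace itself.
  return torch::empty(static_cast<int64_t>(workspace_size), options);
}

// Usage inside a caller:
//   auto workspace = allocate_cutlass_workspace(gemm_op.get_workspace_size(args), a.device());
//   cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);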

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu

Lines changed: 11 additions & 19 deletions

@@ -18,8 +18,6 @@
 #include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"

-#include "cutlass/util/device_memory.h"
-
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"

@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
       0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
       Stride<Int<1>, Int<0>, Int<0>>>;

-  using ScaleBDescriptor =
-      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
-          EpilogueDescriptor, float>;
-
   using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
-      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
-      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
+      Stride<Int<0>, Int<1>, Int<0>>>;
 };

 /*

@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
       cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;

-  using BiasDescriptor =
-      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
-          EpilogueDescriptor, ElementD>;
-
   using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
       Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;

  public:

@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   int64_t ldb = b.stride(1);
   int64_t ldc = out.stride(0);

-  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
-  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
+  using StrideA = Stride<int64_t, Int<1>, int64_t>;
+  using StrideB = Stride<int64_t, Int<1>, int64_t>;
   using StrideC = typename Gemm::StrideC;

-  StrideA a_stride{lda, Int<1>{}, Int<0>{}};
-  StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
+  StrideA a_stride{lda, Int<1>{}, 0};
+  StrideB b_stride{ldb, Int<1>{}, 0};
   StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

   using GemmKernel = typename Gemm::GemmKernel;

@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   CUTLASS_CHECK(gemm_op.can_implement(args));

   size_t workspace_size = gemm_op.get_workspace_size(args);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);

   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

-  cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
+  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }
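
Besides the workspace change, the c3x caller now gives StrideA and StrideB a runtime batch stride (initialized to 0) instead of the compile-time Int<0>, and ScaleB/Bias are instantiated directly with 0 stages now that the RowBroadcastDescriptor helper is no longer used. For a single batch, the runtime-0 stride addresses memory exactly as the old static stride did; the host-side CuTe check below illustrates that, assuming only the CuTe headers fetched by the CMakeLists.txt change (the leading stride of 64 is arbitrary):

// stride_check.cpp -- illustrative only, not part of the commit.
// Verifies that a runtime batch stride of 0 produces the same offsets as the
// old compile-time Int<0> stride when there is a single batch (L = 1).
#include <cute/layout.hpp>

int main() {
  using namespace cute;
  auto shape = make_shape(4, 64, 1);  // (M, N, L)
  auto static_batch  = make_layout(shape, make_stride(int64_t{64}, Int<1>{}, Int<0>{}));
  auto dynamic_batch = make_layout(shape, make_stride(int64_t{64}, Int<1>{}, int64_t{0}));
  // Both layouts map the coordinate (m, n, l = 0) to the same linear offset.
  for (int m = 0; m < 4; ++m) {
    for (int n = 0; n < 64; ++n) {
      if (static_batch(m, n, 0) != dynamic_batch(m, n, 0)) return 1;
    }
  }
  return 0;
}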
