@@ -64,23 +64,19 @@ using namespace detail;

 // Row vector broadcast
 template <
-  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
-  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
   int Stages,
   class CtaTileShapeMNK,
   class Element,
   class StrideMNL = Stride<_0,_1,_0>,
   int Alignment = 128 / sizeof_bits_v<Element>
 >
 struct Sm90RowOrScalarBroadcast {
-  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
-  static_assert(
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,_0>>) || // row vector broadcast, e.g. per-col alpha/bias
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
+  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

-  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
-  struct SharedStorage {
-    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
+  struct SharedStorage {
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
   };

   // This struct has been modified to have a bool indicating that ptr_row is a
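Net effect of the hunk above: the row vector no longer rides the epilogue's TMA load pipeline, so Stages must be 0 and shared memory holds a single CTA_N-wide copy of the row instead of a multi-stage buffer, while the batch stride may be static or dynamic. A minimal instantiation sketch, assuming this header and its CuTe names are in scope; the tile shape and element type below are illustrative choices, not values taken from this diff:

// Hypothetical alias, for illustration only.
using RowBiasBcast = Sm90RowOrScalarBroadcast<
    0,                        // Stages: must be 0 (no smem pipelining of the row)
    Shape<_128, _128, _64>,   // CtaTileShapeMNK (assumed tile shape)
    float,                    // Element (assumed)
    Stride<_0, _1, int>>;     // (M,N) stride fixed to (_0,_1); batch stride may be dynamic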
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
     return args;
   }

+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
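The can_implement hook added above gives the node a host-side validity check of the kind the CUTLASS 3.x fusion framework queries before launch; here it accepts every problem shape and argument set unconditionally, so the only constraints on the row operand are the static_asserts on Stages and StrideMNL.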
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {

   CUTLASS_HOST_DEVICE
   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
-      : params(params),
-        smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
+      : params(params)
+      , smem(const_cast<Element*>(shared_storage.smem.data())) { }

   Params params;
-  Element* smem_row;
+  Element *smem = nullptr;

   CUTLASS_DEVICE bool
   is_producer_load_needed() const {
-    return true;
+    return false;
   }

   CUTLASS_DEVICE bool
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
     return (!params.row_broadcast && *(params.ptr_row) == Element(0));
   }

-  template <int EpiTiles, class GTensor, class STensor>
-  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
-    CUTLASS_DEVICE
-    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
-      : gRow(cute::forward<GTensor>(gRow)),
-        sRow(cute::forward<STensor>(sRow)),
-        params(params) {}
-
-    GTensor gRow;  // (CTA_M,CTA_N)
-    STensor sRow;  // (CTA_M,CTA_N,PIPE)
-    Params const& params;
-
-    CUTLASS_DEVICE void
-    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
-      if (!params.row_broadcast) {
-        return;
-      }
-
-      if (issue_tma_load) {
-        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
-        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
-        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
-        // Issue the TMA bulk copy
-        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
-        // Filter so we don't issue redundant copies over stride-0 modes
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
-      }
-    }
-  };
-
   template <class... Args>
   CUTLASS_DEVICE auto
   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
-
-    auto [M, N, K, L] = args.problem_shape_mnkl;
-    auto [m, n, k, l] = args.tile_coord_mnkl;
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
-    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));  // (CTA_M,CTA_N)
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),  // (CTA_M,CTA_N,PIPE)
-                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
-      cute::move(gRow), cute::move(sRow), params);
+    return EmptyProducerLoadCallbacks{};
   }

-  template <int EpiTiles, class RTensor, class STensor>
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
     CUTLASS_DEVICE
-    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
-      : tCrRow(cute::forward<RTensor>(tCrRow)),
-        tCsRow(cute::forward<STensor>(tCsRow)),
-        params(params) {}
-
-    RTensor tCrRow;  // (CPY,CPY_M,CPY_N)
-    STensor tCsRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
+    ConsumerStoreCallbacks(
+        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
+        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
+        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
+      : tGS_gRow(tGS_gRow_)
+      , tGS_sRow(tGS_sRow_)
+      , tGS_cRow(tGS_cRow_)
+      , tiled_G2S(tiled_g2s_)
+      , tSR_sRow(tSR_sRow_)
+      , tSR_rRow(tSR_rRow_)
+      , tCcRow(tCcRow_)
+      , residue_tCcRow(residue_tCcRow_)
+      , params(params_) {}
+
+    GS_GTensor tGS_gRow;  // (CPY,CPY_M,CPY_N)
+    GS_STensor tGS_sRow;  // (CPY,CPY_M,CPY_N)
+    GS_CTensor tGS_cRow;  // (CPY,CPY_M,CPY_N)
+    Tiled_G2S tiled_G2S;
+
+    SR_STensor tSR_sRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    SR_RTensor tSR_rRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+    CTensor tCcRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    ThrResidue residue_tCcRow;  // (m, n)
+    ThrNum thr_num;
     Params const& params;

     CUTLASS_DEVICE void
-    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
+    begin() {
       if (!params.row_broadcast) {
-        fill(tCrRow, *(params.ptr_row));
+        fill(tSR_rRow, *(params.ptr_row));
         return;
       }

+      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+          continue; // OOB of SMEM
+        }
+        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+          tGS_sRow_flt(i) = tGS_gRow_flt(i);
+        }
+        else {
+          tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
+        }
+      }
+      synchronize();
+    }
+
+    CUTLASS_DEVICE void
+    begin_loop(int epi_m, int epi_n) {
       if (epi_m == 0) { // Assumes M-major subtile loop
-        // Filter so we don't issue redundant copies over stride-0 modes
-        // (only works if 0-strides are in same location, which is by construction)
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
+        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
+        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+        copy(tSR_sRow_flt, tSR_rRow_flt);
       }
     }
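With is_producer_load_needed() now returning false and the producer callbacks reduced to EmptyProducerLoadCallbacks, the TMA producer warps skip this node entirely and all row movement happens on the consumer (store) side: begin() has the epilogue threads copy the row from gmem into the single smem buffer, zero-filling out-of-bounds columns so the later smem loads need no predicates, then synchronizes on the epilogue NamedBarrier; begin_loop() refreshes the per-thread register fragment from smem at the start of each subtile column (epi_m == 0). The filter_zeros() calls collapse stride-0 modes so each distinct row element is touched once rather than once per broadcast copy. A standalone host-side sketch of that idea, with assumed sizes for illustration only:

#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  float row[4] = {1.f, 2.f, 3.f, 4.f};
  // View the 4-element row as an (M,N) = (8,4) broadcast tensor: stride 0 along M.
  Tensor bcast = make_tensor(&row[0], make_shape(Int<8>{}, Int<4>{}),
                                      make_stride(Int<0>{}, Int<1>{}));
  // filter_zeros() shrinks every stride-0 mode to size 1, so iterating the
  // filtered view visits each unique element once (4 elements, not 32).
  Tensor flt = filter_zeros(bcast);
  int unique = size(flt);  // 4, not 32
  return unique == 4 ? 0 : 1;
}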
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {

       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < FragmentSize; ++i) {
-        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
+        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
       }

       return frg_row;
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
   >
   CUTLASS_DEVICE auto
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+    using ThreadCount = decltype(size(args.tiled_copy));

-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),  // (CTA_M,CTA_N,PIPE)
-                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
-                      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));  // (CPY,CPY_M,CPY_N)
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
-      cute::move(tCrRow), cute::move(tCsRow), params);
+    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
+    Tensor sRow = make_tensor(make_smem_ptr(smem),
+        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{}));  // (CTA_M, CTA_N)
+    // // G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                     Layout<Shape<_1, ThreadCount>,
+                                            Stride<_0, _1>>{},
+                                     Layout<_1>{});
+    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+    Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+    // // G2S: Coord
+    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+    Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+    // // S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));  // (CPY,CPY_M,CPY_N)
+
+    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
+      tGS_gRow,
+      tGS_sRow,
+      tGS_cRow, tiled_g2s,
+      tSR_sRow,
+      tSR_rRow,
+      args.tCcD,
+      args.residue_cD,
+      ThreadCount{},
+      params);
   }
 };
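In the rewritten get_consumer_store_callbacks the smem row is a single (CTA_M, CTA_N) view whose M mode has stride 0, the gmem-to-smem transfer is a plain cooperative copy with every epilogue thread laid out along N, and the smem-to-register path reuses the same sm90_partition_for_epilogue tiling as the store itself. A small sketch of that gmem-to-smem thread layout, assuming a thread count of 128 (the real code derives it as decltype(size(args.tiled_copy))):

#include <cute/layout.hpp>

int main() {
  using namespace cute;
  using ThreadCount = _128;  // assumed; taken from the epilogue's tiled copy in the real code
  // One "row" of threads: stride 0 along M, contiguous along N, so each copy
  // pass moves ThreadCount consecutive row elements into smem.
  using ThrLayout = Layout<Shape<_1, ThreadCount>, Stride<_0, _1>>;
  static_assert(size(ThrLayout{}) == Int<128>{}, "all threads participate in the row copy");
  return 0;
}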
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
     return args;
   }

+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {