@@ -64,23 +64,19 @@ using namespace detail;
 
 // Row vector broadcast
 template <
-  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
-  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
   int Stages,
   class CtaTileShapeMNK,
   class Element,
   class StrideMNL = Stride<_0,_1,_0>,
   int Alignment = 128 / sizeof_bits_v<Element>
 >
 struct Sm90RowOrScalarBroadcast {
-  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
-  static_assert(
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,_0>>) || // row vector broadcast, e.g. per-col alpha/bias
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
+  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
 
-  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
-  struct SharedStorage {
-    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
+  struct SharedStorage {
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
   };
 
   // This struct has been modified to have a bool indicating that ptr_row is a
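The asserts and SharedStorage above drop the multi-stage smem pipeline for the row vector: the row is now staged once per CTA tile in a single CTA_N-wide buffer instead of one buffer per pipeline stage. A rough sense of the footprint difference, as a standalone sketch; the tile width, element size, and stage count below are illustrative assumptions, not values from this diff:

```cpp
#include <cstdio>

// Hypothetical numbers for illustration only.
constexpr int CTA_N  = 128;  // size<1>(CtaTileShapeMNK{})
constexpr int Stages = 4;    // old pipelined row-broadcast stage count
constexpr int ElemB  = 2;    // e.g. sizeof(cutlass::half_t)

int main() {
  // Old layout: one CTA_N-wide row buffer per pipeline stage.
  std::printf("old smem_row: %d B\n", CTA_N * Stages * ElemB);  // 1024 B
  // New layout: a single CTA_N-wide row buffer, filled once per CTA tile.
  std::printf("new smem:     %d B\n", CTA_N * ElemB);           // 256 B
  return 0;
}
```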
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
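The new can_implement hook accepts every problem shape unconditionally. A stricter variant could validate inputs up front instead of faulting during the copy; the following standalone sketch of such a check is hypothetical and not part of this change:

```cpp
#include <cstdint>

// Hypothetical stricter predicate (illustration only): reject row pointers
// that break a 16-byte alignment contract before launching the kernel.
template <int AlignBytes, class T>
bool can_implement_row(T const* ptr_row) {
  return reinterpret_cast<std::uintptr_t>(ptr_row) % AlignBytes == 0;
}

int main() {
  alignas(16) float buf[4];
  bool ok  = can_implement_row<16>(buf);      // true: 16B-aligned base
  bool bad = can_implement_row<16>(buf + 1);  // false: 4B offset
  return (ok && !bad) ? 0 : 1;
}
```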
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
 
   CUTLASS_HOST_DEVICE
   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
-      : params(params),
-        smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
+      : params(params)
+      , smem(const_cast<Element*>(shared_storage.smem.data())) { }
 
   Params params;
-  Element* smem_row;
+  Element *smem = nullptr;
 
   CUTLASS_DEVICE bool
   is_producer_load_needed() const {
-    return true;
+    return false;
   }
 
   CUTLASS_DEVICE bool
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
     return (!params.row_broadcast && *(params.ptr_row) == Element(0));
   }
 
-  template <int EpiTiles, class GTensor, class STensor>
-  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
-    CUTLASS_DEVICE
-    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
-      : gRow(cute::forward<GTensor>(gRow)),
-        sRow(cute::forward<STensor>(sRow)),
-        params(params) {}
-
-    GTensor gRow;          // (CTA_M,CTA_N)
-    STensor sRow;          // (CTA_M,CTA_N,PIPE)
-    Params const& params;
-
-    CUTLASS_DEVICE void
-    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
-      if (!params.row_broadcast) {
-        return;
-      }
-
-      if (issue_tma_load) {
-        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
-        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
-        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
-        // Issue the TMA bulk copy
-        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
-        // Filter so we don't issue redundant copies over stride-0 modes
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
-      }
-    }
-  };
-
   template <class... Args>
   CUTLASS_DEVICE auto
   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
-
-    auto [M, N, K, L] = args.problem_shape_mnkl;
-    auto [m, n, k, l] = args.tile_coord_mnkl;
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
-    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));  // (CTA_M,CTA_N)
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                  // (CTA_M,CTA_N,PIPE)
-        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
-      cute::move(gRow), cute::move(sRow), params);
+    return EmptyProducerLoadCallbacks{};
   }
 
-  template <int EpiTiles, class RTensor, class STensor>
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
     CUTLASS_DEVICE
-    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
-      : tCrRow(cute::forward<RTensor>(tCrRow)),
-        tCsRow(cute::forward<STensor>(tCsRow)),
-        params(params) {}
-
-    RTensor tCrRow;        // (CPY,CPY_M,CPY_N)
-    STensor tCsRow;        // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
+    ConsumerStoreCallbacks(
+        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
+        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
+        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
+      : tGS_gRow(tGS_gRow_)
+      , tGS_sRow(tGS_sRow_)
+      , tGS_cRow(tGS_cRow_)
+      , tiled_G2S(tiled_g2s_)
+      , tSR_sRow(tSR_sRow_)
+      , tSR_rRow(tSR_rRow_)
+      , tCcRow(tCcRow_)
+      , residue_tCcRow(residue_tCcRow_)
+      , params(params_) {}
+
+    GS_GTensor tGS_gRow;        // (CPY,CPY_M,CPY_N)
+    GS_STensor tGS_sRow;        // (CPY,CPY_M,CPY_N)
+    GS_CTensor tGS_cRow;        // (CPY,CPY_M,CPY_N)
+    Tiled_G2S tiled_G2S;
+
+    SR_STensor tSR_sRow;        // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    SR_RTensor tSR_rRow;        // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+    CTensor tCcRow;             // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    ThrResidue residue_tCcRow;  // (m, n)
+    ThrNum thr_num;
    Params const& params;
 
     CUTLASS_DEVICE void
-    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
+    begin() {
       if (!params.row_broadcast) {
-        fill(tCrRow, *(params.ptr_row));
+        fill(tSR_rRow, *(params.ptr_row));
         return;
       }
 
+      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+          continue; // OOB of SMEM
+        }
+        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+          tGS_sRow_flt(i) = tGS_gRow_flt(i);
+        }
+        else {
+          tGS_sRow_flt(i) = Element(0); // Zero-fill OOB columns so the LDS can be issued without predicates
+        }
+      }
+      synchronize();
+    }
+
+    CUTLASS_DEVICE void
+    begin_loop(int epi_m, int epi_n) {
       if (epi_m == 0) { // Assumes M-major subtile loop
-        // Filter so we don't issue redundant copies over stride-0 modes
-        // (only works if 0-strides are in same location, which is by construction)
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
+        if (!params.row_broadcast) return; // Do not issue the LDS when the row is a scalar
+        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+        copy(tSR_sRow_flt, tSR_rRow_flt);
       }
     }
 
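The consumer-side path above replaces the producer TMA bulk copy entirely: begin() has every epilogue thread cooperatively stage the row tile into smem, zero-filling columns past the residue so the subsequent smem loads need no predicates, then synchronizes on a named barrier; begin_loop() copies the staged row into registers once per M-major subtile column. A minimal CUDA sketch of the same stage-then-sync pattern, stripped of CuTe; all names and shapes are illustrative, and __syncthreads() stands in for the EpilogueBarrier:

```cpp
#include <cuda_runtime.h>

// Add a broadcast row vector to one CTA tile of `out` (M x N, row-major).
// Example launch: row_broadcast_tile<128, 128><<<dim3(ceil_div(N,128), ceil_div(M,128)), dim3(32, 8)>>>(...)
template <int CTA_M, int CTA_N>
__global__ void row_broadcast_tile(float const* __restrict__ row,  // N-length row vector
                                   float* __restrict__ out,
                                   int m_total, int n_total) {
  __shared__ float s_row[CTA_N];
  int const col0 = blockIdx.x * CTA_N;

  // begin(): cooperatively stage this tile's slice of the row into smem,
  // zero-filling past the residue so later smem reads need no bounds checks.
  for (int j = threadIdx.x; j < CTA_N; j += blockDim.x) {
    int const col = col0 + j;
    s_row[j] = (col < n_total) ? row[col] : 0.f;
  }
  __syncthreads();  // every consumer must see the staged row before reading it

  // visit(): each row of the CTA tile consumes the same smem values.
  for (int i = blockIdx.y * CTA_M + threadIdx.y;
       i < m_total && i < (blockIdx.y + 1) * CTA_M; i += blockDim.y) {
    for (int j = threadIdx.x; j < CTA_N && col0 + j < n_total; j += blockDim.x) {
      out[i * n_total + col0 + j] += s_row[j];
    }
  }
}
```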
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
 
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < FragmentSize; ++i) {
-        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
+        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
       }
 
       return frg_row;
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
   >
   CUTLASS_DEVICE auto
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+    using ThreadCount = decltype(size(args.tiled_copy));
 
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),          // (CTA_M,CTA_N,PIPE)
-        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
-      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));        // (CPY,CPY_M,CPY_N)
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
-      cute::move(tCrRow), cute::move(tCsRow), params);
+    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
+    Tensor sRow = make_tensor(make_smem_ptr(smem),
+        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_stride(_0{}, _1{}));  // (CTA_M, CTA_N)
+    // G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                     Layout< Shape<_1, ThreadCount>,
+                                            Stride<_0, _1>>{},
+                                     Layout<_1>{});
+    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+    Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+    // G2S: coordinates for predication
+    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+    Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+    // S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));    // (CPY,CPY_M,CPY_N)
+
+    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
+      tGS_gRow,
+      tGS_sRow,
+      tGS_cRow, tiled_g2s,
+      tSR_sRow,
+      tSR_rRow,
+      args.tCcD,
+      args.residue_cD,
+      ThreadCount{},
+      params);
   }
 };
 
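Both the (_0,_1)-strided sRow view above and the filter_zeros calls in begin()/begin_loop() rely on the same CuTe idiom: a stride-0 mode turns a row vector into a broadcast matrix view, and filter_zeros collapses that mode so a copy over the filtered view touches each unique element exactly once. A host-side sketch of the idiom (illustrative only; assumes the CuTe headers are on the include path):

```cpp
#include <cute/tensor.hpp>
#include <cstdio>

int main() {
  using namespace cute;
  float data[4] = {1.f, 2.f, 3.f, 4.f};

  // A (CTA_M=8, CTA_N=4) view of a 4-element row: stride 0 in M broadcasts
  // the same element down every row, like sRow above.
  Tensor bcast = make_tensor(&data[0], make_shape(Int<8>{}, Int<4>{}),
                             make_stride(Int<0>{}, Int<1>{}));

  // filter_zeros shrinks stride-0 modes to extent 1, so iterating or copying
  // over the filtered view visits each unique element exactly once.
  Tensor flt = filter_zeros(bcast);
  std::printf("broadcast view: %d elems, filtered: %d unique\n",
              int(size(bcast)), int(size(flt)));  // prints 32 and 4
  return 0;
}
```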
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {