@@ -263,9 +263,7 @@ public:
263263 cta_m, cta_n
264264 );
265265 }
266-
267- // TODO(yifu): remove this once cutlass 3.5.1 upgrade is completed
268- #if CUTLASS_VERSION != 351
266+ // Kernel helper function to get next work ID
269267 template <class WorkIdPipeline , class WorkIdPipelineState >
270268 CUTLASS_DEVICE
271269 auto
@@ -280,18 +278,19 @@ public:
280278 // Return true to indicate that the WorkID pipeline state should be advanced
281279 return cute::make_tuple (new_work_tile_info, true );
282280 }
283- #else
284- CUTLASS_DEVICE
285- auto
286- fetch_next_work (WorkTileInfo work_tile_info) {
287- if (continue_current_work (work_tile_info)) {
288- return work_tile_info;
289- }
290281
291- advance_to_next_work ();
292- return get_current_work ();
282+ CUTLASS_DEVICE
283+ static auto
284+ work_tile_to_cta_coord (WorkTileInfo work_tile_info) {
285+ // Get every cta coord in three dimensions of the cluster
286+ auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster ();
287+ return make_coord (
288+ work_tile_info.M_idx + static_cast <int32_t >(cta_m_in_cluster),
289+ work_tile_info.N_idx + static_cast <int32_t >(cta_n_in_cluster),
290+ _,
291+ work_tile_info.L_idx + static_cast <int32_t >(cta_l_in_cluster)
292+ );
293293 }
294- #endif
295294
296295 // Given the inputs, computes the physical grid we should launch.
297296 template <class ProblemShapeMNKL , class BlockShape , class ClusterShape >
@@ -348,6 +347,20 @@ public:
348347 );
349348 }
350349
350+ // Convert CTA-level work tile info to cluster-level tile coord
351+ CUTLASS_DEVICE
352+ cute::Coord<int ,int ,int ,int >
353+ tile_info_to_coord_mnkl (WorkTileInfo work_tile_info) const {
354+ // TileScheduler works at CTA-level, kernel works at cluster-level
355+ int m_coord = idx2crd (work_tile_info.M_idx / params.cluster_shape_m_ ,
356+ params.problem_tiles_m_ );
357+ int n_coord = idx2crd (work_tile_info.N_idx / params.cluster_shape_n_ ,
358+ params.problem_tiles_n_ );
359+ int l_coord = idx2crd (work_tile_info.L_idx ,
360+ params.problem_tiles_l_ );
361+ return make_coord (m_coord, n_coord, _, l_coord);
362+ }
363+
351364 // Returns whether the block assigned this work should compute the epilogue for the corresponding
352365 // output tile. For the basic tile scheduler, this is always true.
353366 CUTLASS_HOST_DEVICE
@@ -458,7 +471,7 @@ public:
458471 template <class ProblemShape , class ElementAccumulator >
459472 static cutlass::Status
460473 initialize_workspace (Arguments const &, void *, cudaStream_t, ProblemShape, KernelHardwareInfo const &,
461- uint32_t , const uint32_t = 1 , CudaHostAdapter* cuda_adapter = nullptr ) {
474+ uint32_t , const uint32_t = 1 ) {
462475 return Status::kSuccess ;
463476 }
464477public:
0 commit comments