@@ -263,7 +263,9 @@ public:
263263 cta_m, cta_n
264264 );
265265 }
266- // Kernel helper function to get next work ID
266+
267+ // TODO(yifu): remove this once cutlass 3.5.1 upgrade is completed
268+ #if CUTLASS_VERSION != 351
267269 template <class WorkIdPipeline , class WorkIdPipelineState >
268270 CUTLASS_DEVICE
269271 auto
@@ -278,19 +280,18 @@ public:
278280 // Return true to indicate that the WorkID pipeline state should be advanced
279281 return cute::make_tuple (new_work_tile_info, true );
280282 }
281-
283+ # else
282284 CUTLASS_DEVICE
283- static auto
284- work_tile_to_cta_coord (WorkTileInfo work_tile_info) {
285- // Get every cta coord in three dimensions of the cluster
286- auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster ();
287- return make_coord (
288- work_tile_info.M_idx + static_cast <int32_t >(cta_m_in_cluster),
289- work_tile_info.N_idx + static_cast <int32_t >(cta_n_in_cluster),
290- _,
291- work_tile_info.L_idx + static_cast <int32_t >(cta_l_in_cluster)
292- );
285+ auto
286+ fetch_next_work (WorkTileInfo work_tile_info) {
287+ if (continue_current_work (work_tile_info)) {
288+ return work_tile_info;
289+ }
290+
291+ advance_to_next_work ();
292+ return get_current_work ();
293293 }
294+ #endif
294295
295296 // Given the inputs, computes the physical grid we should launch.
296297 template <class ProblemShapeMNKL , class BlockShape , class ClusterShape >
@@ -347,20 +348,6 @@ public:
347348 );
348349 }
349350
350- // Convert CTA-level work tile info to cluster-level tile coord
351- CUTLASS_DEVICE
352- cute::Coord<int ,int ,int ,int >
353- tile_info_to_coord_mnkl (WorkTileInfo work_tile_info) const {
354- // TileScheduler works at CTA-level, kernel works at cluster-level
355- int m_coord = idx2crd (work_tile_info.M_idx / params.cluster_shape_m_ ,
356- params.problem_tiles_m_ );
357- int n_coord = idx2crd (work_tile_info.N_idx / params.cluster_shape_n_ ,
358- params.problem_tiles_n_ );
359- int l_coord = idx2crd (work_tile_info.L_idx ,
360- params.problem_tiles_l_ );
361- return make_coord (m_coord, n_coord, _, l_coord);
362- }
363-
364351 // Returns whether the block assigned this work should compute the epilogue for the corresponding
365352 // output tile. For the basic tile scheduler, this is always true.
366353 CUTLASS_HOST_DEVICE
@@ -471,7 +458,7 @@ public:
471458 template <class ProblemShape , class ElementAccumulator >
472459 static cutlass::Status
473460 initialize_workspace (Arguments const &, void *, cudaStream_t, ProblemShape, KernelHardwareInfo const &,
474- uint32_t , const uint32_t = 1 ) {
461+ uint32_t , const uint32_t = 1 , CudaHostAdapter* cuda_adapter = nullptr ) {
475462 return Status::kSuccess ;
476463 }
477464public:
0 commit comments