@@ -13,11 +13,14 @@ namespace llm {
 
 using namespace cute;
 
-template <class CollectiveMainloop_, class CollectiveEpilogue_>
+template <class CollectiveMainloop_,
+          class CollectiveEpilogue_,
+          class TileScheduler_>
 class Sm80KernelMha {
  public:
   using CollectiveMainloop = CollectiveMainloop_;
   using CollectiveEpilogue = CollectiveEpilogue_;
+  using TileScheduler = TileScheduler_;
 
   using TiledMma = typename CollectiveMainloop::TiledMma;
 
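
The new `TileScheduler_` template parameter decouples work assignment from the CUDA grid: instead of deriving its block coordinate directly from `blockIdx`, the kernel now iterates over whatever coordinates the scheduler hands it. Below is a minimal sketch of the interface this kernel appears to require, inferred only from its usage in this diff; the name `SingleTileScheduler` and the iterator details are illustrative assumptions, not the repo's actual scheduler.

```cpp
#include <cute/tensor.hpp>  // cute::make_coord, CUTE_DEVICE

// Hypothetical scheduler reproducing the old one-block-per-tile mapping:
// each thread block processes exactly one (batch_idx, m_block_idx, kv_head_idx).
struct SingleTileScheduler {
  // host-side arguments used to size the launch grid
  struct Arguments {
    int n_m_blocks;  // e.g. ceil_div(q_packed_len, kBlockM)
    int n_batches;
    int n_kv_heads;
  };
  // device-side params; identical to Arguments in this sketch
  using Params = Arguments;

  static dim3 get_grid_shape(Arguments const& args) {
    return dim3(args.n_m_blocks, args.n_batches, args.n_kv_heads);
  }

  CUTE_DEVICE SingleTileScheduler(Params const& /*params*/) {}

  // Range interface: yields a single coordinate, then terminates.
  struct Iterator {
    bool valid = false;
    CUTE_DEVICE auto operator*() const {
      // same mapping the kernel previously read from blockIdx
      return cute::make_coord(static_cast<int>(blockIdx.y),   // batch_idx
                              static_cast<int>(blockIdx.x),   // m_block_idx
                              static_cast<int>(blockIdx.z));  // kv_head_idx
    }
    CUTE_DEVICE Iterator& operator++() {
      valid = false;
      return *this;
    }
    CUTE_DEVICE bool operator!=(Iterator const& other) const {
      return valid != other.valid;
    }
  };
  CUTE_DEVICE Iterator begin() const { return {/*valid=*/true}; }
  CUTE_DEVICE Iterator end() const { return {/*valid=*/false}; }
};
```

A persistent-style scheduler could instead stride a flattened tile index across a fixed grid, and the kernel body below would handle it unchanged.
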
@@ -39,45 +42,22 @@ class Sm80KernelMha {
   // Kernel params
   using MainloopParams = typename CollectiveMainloop::Params;
   using EpilogueParams = typename CollectiveEpilogue::Params;
+  using TileSchedulerParams = typename TileScheduler::Params;
+
+  // returns grid and block shape for kernel launch
+  using TileSchedulerArgs = typename TileScheduler::Arguments;
+  static dim3 get_grid_shape(TileSchedulerArgs const& args) {
+    return TileScheduler::get_grid_shape(args);
+  }
+  static dim3 get_block_shape() { return kMmaThreads; }
 
   template <class Params>
-  CUTE_DEVICE void operator()(const Params& params, char* smem) {
+  CUTE_DEVICE void operator()(const Params& params,
+                              const TileSchedulerParams& scheduler_params,
+                              char* smem) {
     CollectiveMainloop mha;
     CollectiveEpilogue epilogue;
-
-    const auto tidx = threadIdx.x;
-
-    // block coord
-    const int m_block_idx = blockIdx.x;
-    const int batch_idx = blockIdx.y;
-    const int kv_head_idx = blockIdx.z;
-    auto block_coord_mnk = make_coord(m_block_idx, batch_idx, kv_head_idx);
-
-    // (q_packed_len, HEAD_DIM)
-    MHATile<Params> tile(params, batch_idx, kv_head_idx);
-    auto [Q, O] = tile.template get_qo_tile<Element>();
-    // (kv_len, HEAD_DIM)
-    auto [K, V] = tile.template get_kv_tile<Element>();
-
-    // problem shape
-    const int q_packed_len = size<0>(Q);
-    const int kv_len = size<0>(K);
-    const int head_dim = params.head_dim;
-    auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
-
-    if (m_block_idx * kBlockM >= q_packed_len) {
-      // m out of bound, return
-      return;
-    }
-
-    // (BLK_M, HEAD_DIM)
-    Tensor gQ =
-        local_tile(Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-    Tensor gO =
-        local_tile(O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
-    // (BLK_N, HEAD_DIM, n)
-    Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
-    Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
+    TileScheduler scheduler(scheduler_params);
 
     // construct params
     MainloopParams mainloop_params{params.sliding_window,
@@ -88,35 +68,68 @@ class Sm80KernelMha {
                                    params.group_size};
     EpilogueParams epilogue_params;
 
-    TiledMma tiled_mma;
-    // accumulator: (MMA, MMA_M, MMA_N)
-    auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
-    clear(tOrAccO);
-
-    constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
-    OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
-
-    // mainloop
-    mha(mainloop_params,
-        gQ,
-        gK,
-        gV,
-        tOrAccO,
-        softmax,
-        tidx,
-        block_coord_mnk,
-        problem_shape_mnk,
-        smem);
-
-    // epilogue
-    epilogue(epilogue_params,
-             tOrAccO,
-             tiled_mma,
-             gO,
-             tidx,
-             block_coord_mnk,
-             problem_shape_mnk,
-             smem);
+    // process each block
+    for (const auto block_coord : scheduler) {
+      // block coord: (batch_idx, m_block_idx, kv_head_idx)
+      const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
+      const auto tidx = threadIdx.x;
+
+      // (q_packed_len, HEAD_DIM)
+      MHATile<Params> tile(params, batch_idx, kv_head_idx);
+      auto [Q, O] = tile.template get_qo_tile<Element>();
+      // (kv_len, HEAD_DIM)
+      auto [K, V] = tile.template get_kv_tile<Element>();
+
+      // problem shape
+      const int q_packed_len = size<0>(Q);
+      const int kv_len = size<0>(K);
+      const int head_dim = params.head_dim;
+      auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
+
+      if (m_block_idx * kBlockM >= q_packed_len) {
+        // m out of bound, skip this block
+        continue;
+      }
+
+      // (BLK_M, HEAD_DIM)
+      Tensor gQ = local_tile(
+          Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
+      Tensor gO = local_tile(
+          O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
+      // (BLK_N, HEAD_DIM, n)
+      Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
+      Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
+
+      TiledMma tiled_mma;
+      // accumulator: (MMA, MMA_M, MMA_N)
+      auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
+      clear(tOrAccO);
+
+      constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
+      OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
+
+      // mainloop
+      mha(mainloop_params,
+          gQ,
+          gK,
+          gV,
+          tOrAccO,
+          softmax,
+          tidx,
+          block_coord,
+          problem_shape_mnk,
+          smem);
+
+      // epilogue
+      epilogue(epilogue_params,
+               tOrAccO,
+               tiled_mma,
+               gO,
+               tidx,
+               block_coord,
+               problem_shape_mnk,
+               smem);
+    }
   }
 };
 
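
For context, a host-side launch under the new scheme might look like the sketch below. The `__global__` wrapper, the launch helper, and all names other than the `Sm80KernelMha` members shown in the diff are assumptions for illustration, and the sketch assumes the scheduler's `Params` is constructible from its `Arguments`.

```cpp
// Hypothetical entry point: forwards the MHA params and scheduler params
// into the device-callable kernel class, with dynamic shared memory.
template <class Kernel, class Params>
__global__ void mha_kernel_sm80(
    Params params, typename Kernel::TileSchedulerParams scheduler_params) {
  extern __shared__ char smem[];
  Kernel kernel;
  kernel(params, scheduler_params, smem);
}

// Hypothetical launch helper: the grid comes from the scheduler, the block
// size from the MMA configuration, so the caller never sees the tiling policy.
template <class Kernel, class Params>
void launch_mha(const Params& params,
                typename Kernel::TileSchedulerArgs const& args,
                size_t smem_size,
                cudaStream_t stream) {
  const dim3 grid = Kernel::get_grid_shape(args);
  const dim3 block = Kernel::get_block_shape();  // kMmaThreads, widened to dim3
  typename Kernel::TileSchedulerParams scheduler_params = args;
  mha_kernel_sm80<Kernel>
      <<<grid, block, smem_size, stream>>>(params, scheduler_params);
}
```

Because both the grid shape and the per-block coordinate stream are owned by the scheduler, swapping in a persistent scheduler changes only `get_grid_shape` and the coordinates the device-side `for` loop yields; the kernel body stays identical.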