Skip to content

Commit dd08c28

Browse files
authored
refactor: add _1 into stride for contiguous dim (#466)
1 parent a9ddddf commit dd08c28

14 files changed

+86
-85
lines changed

src/kernels/attention/attn_api.cpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,14 @@ void paged_kv_varlen_mha(
3434
// construct attention params
3535
MHAPagedKVParams params;
3636
params.q_ptr = query.const_data_ptr();
37-
params.q_stride = make_stride(query.stride(0), query.stride(1));
37+
params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
3838
params.k_ptr = key_cache.const_data_ptr();
39-
params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
39+
params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
4040
params.v_ptr = value_cache.const_data_ptr();
41-
params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
41+
params.v_stride =
42+
make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
4243
params.o_ptr = out.mutable_data_ptr();
43-
params.o_stride = make_stride(out.stride(0), out.stride(1));
44+
params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
4445
params.alibi_slopes_ptr = alibi_slopes.has_value()
4546
? alibi_slopes.value().const_data_ptr<float>()
4647
: nullptr;

src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -48,13 +48,14 @@ torch::Tensor mha_pagedkv_sm80(
4848
// construct attention params
4949
MHAPagedKVParams params;
5050
params.q_ptr = query.const_data_ptr();
51-
params.q_stride = make_stride(query.stride(0), query.stride(1));
51+
params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
5252
params.k_ptr = key_cache.const_data_ptr();
53-
params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
53+
params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
5454
params.v_ptr = value_cache.const_data_ptr();
55-
params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
55+
params.v_stride =
56+
make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
5657
params.o_ptr = out.mutable_data_ptr();
57-
params.o_stride = make_stride(out.stride(0), out.stride(1));
58+
params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
5859
params.alibi_slopes_ptr = alibi_slopes.has_value()
5960
? alibi_slopes.value().const_data_ptr<float>()
6061
: nullptr;
@@ -243,4 +244,4 @@ INSTANTIATE_TEST_SUITE_P(
243244
::testing::Values(-1, 0, 10) // sliding window
244245
));
245246

246-
} // namespace llm
247+
} // namespace llm

src/kernels/attention/mha_kernel_sm80_test.cu

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -60,14 +60,16 @@ torch::Tensor mha_sm80(
6060
MHAParams params;
6161
params.q_ptr = query.const_data_ptr();
6262
params.q_stride =
63-
make_stride(query.stride(0), query.stride(1), query.stride(2));
63+
make_stride(query.stride(0), query.stride(1), query.stride(2), _1{});
6464
params.k_ptr = key.const_data_ptr();
65-
params.k_stride = make_stride(key.stride(0), key.stride(1), key.stride(2));
65+
params.k_stride =
66+
make_stride(key.stride(0), key.stride(1), key.stride(2), _1{});
6667
params.v_ptr = value.const_data_ptr();
6768
params.v_stride =
68-
make_stride(value.stride(0), value.stride(1), value.stride(2));
69+
make_stride(value.stride(0), value.stride(1), value.stride(2), _1{});
6970
params.o_ptr = out.mutable_data_ptr();
70-
params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
71+
params.o_stride =
72+
make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
7173
params.alibi_slopes_ptr = alibi_slopes.has_value()
7274
? alibi_slopes.value().const_data_ptr<float>()
7375
: nullptr;
@@ -168,4 +170,4 @@ INSTANTIATE_TEST_SUITE_P(
168170
::testing::Values(-1, 0, 10) // sliding window
169171
));
170172

171-
} // namespace llm
173+
} // namespace llm

src/kernels/attention/mha_params.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ struct MHAParamsCommon {
8282

8383
struct MHAParams : public MHAParamsCommon {
8484
// (batch, seq, head, dim): last dimension is contiguous
85-
using Stride = cute::Stride<int64_t, int64_t, int64_t /*,_1*/>;
85+
using Stride = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
8686

8787
Stride q_stride;
8888
Stride k_stride;
@@ -97,7 +97,7 @@ struct MHAParams : public MHAParamsCommon {
9797
// paged KV cache + variable length sequence
9898
struct MHAPagedKVParams : public MHAParamsCommon {
9999
// (seq, head, dim): last dimension is contiguous
100-
using Stride = cute::Stride<int64_t, int64_t /*,_1*/>;
100+
using Stride = cute::Stride<int64_t, int64_t, cute::_1>;
101101

102102
Stride q_stride;
103103
Stride k_stride;
@@ -116,4 +116,4 @@ struct MHAPagedKVParams : public MHAParamsCommon {
116116
const int* __restrict__ block_cu_lens = nullptr;
117117
};
118118

119-
} // namespace llm
119+
} // namespace llm

src/kernels/attention/mha_sm80_bench.cu

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,14 +48,16 @@ void mha_bench_sm80(nvbench::state& state) {
4848
MHAParams params;
4949
params.q_ptr = query.const_data_ptr();
5050
params.q_stride =
51-
make_stride(query.stride(0), query.stride(1), query.stride(2));
51+
make_stride(query.stride(0), query.stride(1), query.stride(2), _1{});
5252
params.k_ptr = key.const_data_ptr();
53-
params.k_stride = make_stride(key.stride(0), key.stride(1), key.stride(2));
53+
params.k_stride =
54+
make_stride(key.stride(0), key.stride(1), key.stride(2), _1{});
5455
params.v_ptr = value.const_data_ptr();
5556
params.v_stride =
56-
make_stride(value.stride(0), value.stride(1), value.stride(2));
57+
make_stride(value.stride(0), value.stride(1), value.stride(2), _1{});
5758
params.o_ptr = out.mutable_data_ptr();
58-
params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
59+
params.o_stride =
60+
make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
5961
params.alibi_slopes_ptr =
6062
alibi ? alibi_slopes.value().const_data_ptr<float>() : nullptr;
6163
params.batch_size = batch_size;

src/kernels/attention/mha_sm80_pagedkv_bench.cu

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -92,13 +92,14 @@ void mha_bench_sm80(nvbench::state& state) {
9292
// construct attention params
9393
MHAPagedKVParams params;
9494
params.q_ptr = query.const_data_ptr();
95-
params.q_stride = make_stride(query.stride(0), query.stride(1));
95+
params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
9696
params.k_ptr = key_cache.const_data_ptr();
97-
params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
97+
params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
9898
params.v_ptr = value_cache.const_data_ptr();
99-
params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
99+
params.v_stride =
100+
make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
100101
params.o_ptr = out.mutable_data_ptr();
101-
params.o_stride = make_stride(out.stride(0), out.stride(1));
102+
params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
102103
params.alibi_slopes_ptr =
103104
alibi ? alibi_slopes.value().const_data_ptr<float>() : nullptr;
104105
params.batch_size = batch_size;

src/kernels/attention/mha_tile.h

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -44,18 +44,14 @@ struct MHATile<MHAParams> {
4444
auto q = make_gather_tensor(
4545
make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
4646
make_shape(packed_len, params_.head_dim),
47-
make_stride(
48-
make_stride(get<1>(params_.q_stride), get<2>(params_.q_stride)),
49-
_1{}),
47+
make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)),
5048
packed_idx_to_coord);
5149

5250
const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
5351
auto o = make_gather_tensor(
5452
make_gmem_ptr((Element*)params_.o_ptr + o_offset),
5553
make_shape(packed_len, params_.head_dim),
56-
make_stride(
57-
make_stride(get<1>(params_.o_stride), get<2>(params_.o_stride)),
58-
_1{}),
54+
make_stride(select<1, 2>(params_.o_stride), get<3>(params_.o_stride)),
5955
packed_idx_to_coord);
6056
return make_tuple(q, o);
6157
}
@@ -72,12 +68,12 @@ struct MHATile<MHAParams> {
7268
auto k =
7369
make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
7470
make_shape(params_.kv_len, params_.head_dim),
75-
make_stride(get<1>(params_.k_stride), _1{}));
71+
select<1, 3>(params_.k_stride));
7672
// v[batch_idx, :, kv_head_idx, :]
7773
auto v =
7874
make_tensor(make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
7975
make_shape(params_.kv_len, params_.head_dim),
80-
make_stride(get<1>(params_.v_stride), _1{}));
76+
select<1, 3>(params_.v_stride));
8177
return make_tuple(k, v);
8278
}
8379
};
@@ -112,18 +108,14 @@ struct MHATile<MHAPagedKVParams> {
112108
auto q = make_gather_tensor(
113109
make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
114110
make_shape(packed_len, params_.head_dim),
115-
make_stride(
116-
make_stride(get<0>(params_.q_stride), get<1>(params_.q_stride)),
117-
_1{}),
111+
make_stride(select<0, 1>(params_.q_stride), get<2>(params_.q_stride)),
118112
packed_idx_to_coord);
119113

120114
const auto o_offset = begin * get<0>(params_.o_stride);
121115
auto o = make_gather_tensor(
122116
make_gmem_ptr((Element*)params_.o_ptr + o_offset),
123117
make_shape(packed_len, params_.head_dim),
124-
make_stride(
125-
make_stride(get<0>(params_.o_stride), get<1>(params_.o_stride)),
126-
_1{}),
118+
make_stride(select<0, 1>(params_.o_stride), get<2>(params_.o_stride)),
127119
packed_idx_to_coord);
128120
return make_tuple(q, o);
129121
}
@@ -148,18 +140,18 @@ struct MHATile<MHAPagedKVParams> {
148140
auto k = make_gather_tensor(
149141
make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
150142
make_shape(kv_len, params_.head_dim),
151-
make_stride(get<0>(params_.k_stride), _1{}),
143+
select<0, 2>(params_.k_stride),
152144
idx_to_slot);
153145

154146
// v[:, kv_head_idx, :]
155147
const auto v_offset = kv_head_idx_ * get<1>(params_.v_stride);
156148
auto v = make_gather_tensor(
157149
make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
158150
make_shape(kv_len, params_.head_dim),
159-
make_stride(get<0>(params_.v_stride), _1{}),
151+
select<0, 2>(params_.v_stride),
160152
idx_to_slot);
161153
return make_tuple(k, v);
162154
}
163155
};
164156

165-
} // namespace llm
157+
} // namespace llm

src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -84,16 +84,16 @@ torch::Tensor mla_pagedkv_sm80(
8484
// construct attention params
8585
MLAPagedKVParams params;
8686
params.q_ptr = q.const_data_ptr();
87-
params.q_stride = make_stride(q.stride(0), q.stride(1));
87+
params.q_stride = make_stride(q.stride(0), q.stride(1), _1{});
8888
params.kv_ptr = kv_cache.const_data_ptr();
89-
params.kv_stride = make_stride(kv_cache.stride(0));
89+
params.kv_stride = make_stride(kv_cache.stride(0), _1{});
9090
params.q_rope_ptr = q_rope.const_data_ptr();
91-
params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1));
91+
params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1), _1{});
9292
params.k_rope_ptr = k_rope_cache.const_data_ptr();
93-
params.k_rope_stride = make_stride(k_rope_cache.stride(0));
93+
params.k_rope_stride = make_stride(k_rope_cache.stride(0), _1{});
9494

9595
params.o_ptr = out.mutable_data_ptr();
96-
params.o_stride = make_stride(out.stride(0), out.stride(1));
96+
params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
9797

9898
params.batch_size = batch_size;
9999
params.max_q_len = max_q_len;
@@ -280,4 +280,4 @@ INSTANTIATE_TEST_SUITE_P(
280280
::testing::Values(64) // rope_head_dim
281281
));
282282

283-
} // namespace llm
283+
} // namespace llm

src/kernels/attention/mla_kernel_sm80_test.cu

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -83,18 +83,19 @@ torch::Tensor mla_sm80(
8383
// construct attention params
8484
MLAParams params;
8585
params.q_ptr = q.const_data_ptr();
86-
params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2));
86+
params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2), _1{});
8787
params.kv_ptr = kv.const_data_ptr();
88-
params.kv_stride = make_stride(kv.stride(0), kv.stride(1));
88+
params.kv_stride = make_stride(kv.stride(0), kv.stride(1), _1{});
8989

9090
params.q_rope_ptr = q_rope.const_data_ptr();
9191
params.q_rope_stride =
92-
make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2));
92+
make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2), _1{});
9393
params.k_rope_ptr = k_rope.const_data_ptr();
94-
params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1));
94+
params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1), _1{});
9595

9696
params.o_ptr = out.mutable_data_ptr();
97-
params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
97+
params.o_stride =
98+
make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
9899

99100
params.batch_size = batch_size;
100101
params.max_q_len = q_len;
@@ -193,4 +194,4 @@ INSTANTIATE_TEST_SUITE_P(
193194
::testing::Values(64) // rope_head_dim
194195
));
195196

196-
} // namespace llm
197+
} // namespace llm

src/kernels/attention/mla_params.h

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -64,9 +64,9 @@ struct MLAParamsCommon {
6464

6565
struct MLAParams : public MLAParamsCommon {
6666
// Q/O: (batch, seq, head, dim): last dimension is contiguous
67-
using Stride = cute::Stride<int64_t, int64_t, int64_t /*,_1*/>;
67+
using Stride = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
6868
// KV: (batch, seq, dim): last dimension is contiguous
69-
using KV_Stride = cute::Stride<int64_t, int64_t /*,_1*/>;
69+
using KV_Stride = cute::Stride<int64_t, int64_t, cute::_1>;
7070

7171
Stride q_stride;
7272
Stride q_rope_stride;
@@ -84,9 +84,9 @@ struct MLAParams : public MLAParamsCommon {
8484
// paged KV cache + variable length sequence
8585
struct MLAPagedKVParams : public MLAParamsCommon {
8686
// Q/O: (seq, head, dim): last dimension is contiguous
87-
using Stride = cute::Stride<int64_t, int64_t /*,_1*/>;
87+
using Stride = cute::Stride<int64_t, int64_t, cute::_1>;
8888
// KV: (seq, dim): last dimension is contiguous
89-
using KV_Stride = cute::Stride<int64_t /*,_1*/>;
89+
using KV_Stride = cute::Stride<int64_t, cute::_1>;
9090

9191
Stride q_stride;
9292
Stride q_rope_stride;
@@ -108,4 +108,4 @@ struct MLAPagedKVParams : public MLAParamsCommon {
108108
const int* __restrict__ block_cu_lens = nullptr;
109109
};
110110

111-
} // namespace llm
111+
} // namespace llm

0 commit comments

Comments
 (0)