9 changes: 5 additions & 4 deletions src/kernels/attention/attn_api.cpp
@@ -34,13 +34,14 @@ void paged_kv_varlen_mha(
   // construct attention params
   MHAPagedKVParams params;
   params.q_ptr = query.const_data_ptr();
-  params.q_stride = make_stride(query.stride(0), query.stride(1));
+  params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
   params.k_ptr = key_cache.const_data_ptr();
-  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
+  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
   params.v_ptr = value_cache.const_data_ptr();
-  params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
+  params.v_stride =
+      make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
   params.alibi_slopes_ptr = alibi_slopes.has_value()
       ? alibi_slopes.value().const_data_ptr<float>()
       : nullptr;
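Note: the same mechanical change repeats at every call site in this PR. Each `make_stride(...)` gains a trailing `_1{}` (`cute::Int<1>`), so the unit stride of the contiguous last dimension is carried in the stride's type rather than as a runtime `int64_t`. A minimal standalone sketch of the difference, with hypothetical shapes (not part of the diff):

// sketch.cu -- hypothetical, for illustration only
#include <cute/tensor.hpp>

#include <cstdint>
#include <type_traits>

using namespace cute;

int main() {
  // Strides of a contiguous (seq=128, head=8, dim=64) tensor.
  int64_t seq_stride = 8 * 64;
  int64_t head_stride = 64;

  // Before: the unit stride is a runtime value; the compiler cannot
  // assume the last dimension is contiguous.
  auto old_stride = make_stride(seq_stride, head_stride, int64_t{1});
  // After: the unit stride is the empty type cute::_1 (Int<1>), a
  // compile-time constant that takes no storage.
  auto new_stride = make_stride(seq_stride, head_stride, _1{});

  static_assert(
      std::is_same_v<decltype(new_stride), Stride<int64_t, int64_t, _1>>,
      "contiguity of the last dimension is now part of the type");
  (void)old_stride;
  return 0;
}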
11 changes: 6 additions & 5 deletions src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu
@@ -48,13 +48,14 @@ torch::Tensor mha_pagedkv_sm80(
   // construct attention params
   MHAPagedKVParams params;
   params.q_ptr = query.const_data_ptr();
-  params.q_stride = make_stride(query.stride(0), query.stride(1));
+  params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
   params.k_ptr = key_cache.const_data_ptr();
-  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
+  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
   params.v_ptr = value_cache.const_data_ptr();
-  params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
+  params.v_stride =
+      make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
   params.alibi_slopes_ptr = alibi_slopes.has_value()
       ? alibi_slopes.value().const_data_ptr<float>()
       : nullptr;
@@ -243,4 +244,4 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(-1, 0, 10) // sliding window
     ));
 
-} // namespace llm
\ No newline at end of file
+} // namespace llm
12 changes: 7 additions & 5 deletions src/kernels/attention/mha_kernel_sm80_test.cu
@@ -60,14 +60,16 @@ torch::Tensor mha_sm80(
   MHAParams params;
   params.q_ptr = query.const_data_ptr();
   params.q_stride =
-      make_stride(query.stride(0), query.stride(1), query.stride(2));
+      make_stride(query.stride(0), query.stride(1), query.stride(2), _1{});
   params.k_ptr = key.const_data_ptr();
-  params.k_stride = make_stride(key.stride(0), key.stride(1), key.stride(2));
+  params.k_stride =
+      make_stride(key.stride(0), key.stride(1), key.stride(2), _1{});
   params.v_ptr = value.const_data_ptr();
   params.v_stride =
-      make_stride(value.stride(0), value.stride(1), value.stride(2));
+      make_stride(value.stride(0), value.stride(1), value.stride(2), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
   params.alibi_slopes_ptr = alibi_slopes.has_value()
       ? alibi_slopes.value().const_data_ptr<float>()
       : nullptr;
@@ -168,4 +170,4 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(-1, 0, 10) // sliding window
     ));
 
-} // namespace llm
\ No newline at end of file
+} // namespace llm
6 changes: 3 additions & 3 deletions src/kernels/attention/mha_params.h
@@ -82,7 +82,7 @@ struct MHAParamsCommon {
 
 struct MHAParams : public MHAParamsCommon {
   // (batch, seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t, int64_t /*,_1*/>;
+  using Stride = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
 
   Stride q_stride;
   Stride k_stride;
@@ -97,7 +97,7 @@ struct MHAParams : public MHAParamsCommon {
 // paged KV cache + variable length sequence
 struct MHAPagedKVParams : public MHAParamsCommon {
   // (seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t /*,_1*/>;
+  using Stride = cute::Stride<int64_t, int64_t, cute::_1>;
 
   Stride q_stride;
   Stride k_stride;
@@ -116,4 +116,4 @@ struct MHAPagedKVParams : public MHAParamsCommon {
   const int* __restrict__ block_cu_lens = nullptr;
 };
 
-} // namespace llm
\ No newline at end of file
+} // namespace llm
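Note: this is the type-level half of the change. The stride aliases previously dropped the trailing dimension (the commented-out /*,_1*/), so their rank was one less than the logical tensor rank and every kernel had to re-append `_1{}` by hand. With `cute::_1` spelled out, the stride rank matches the shape rank and a full layout composes directly. A hedged sketch with made-up sizes (not part of the diff):

// sketch.cu -- hypothetical, for illustration only
#include <cute/tensor.hpp>

#include <cstdint>

using namespace cute;

int main() {
  // Mirrors MHAParams::Stride: (batch, seq, head, dim), contiguous dim.
  using Stride = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;

  int64_t batch = 2, seq = 128, head = 8, dim = 64;
  Stride q_stride = make_stride(seq * head * dim, head * dim, dim, _1{});

  // A rank-4 stride now pairs with a rank-4 shape without any fixup.
  auto layout = make_layout(make_shape(batch, seq, head, dim), q_stride);
  // offset(b, s, h, d) = b*seq*head*dim + s*head*dim + h*dim + d
  int64_t off = layout(1, 2, 3, 4);
  (void)off;
  return 0;
}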
10 changes: 6 additions & 4 deletions src/kernels/attention/mha_sm80_bench.cu
@@ -48,14 +48,16 @@ void mha_bench_sm80(nvbench::state& state) {
   MHAParams params;
   params.q_ptr = query.const_data_ptr();
   params.q_stride =
-      make_stride(query.stride(0), query.stride(1), query.stride(2));
+      make_stride(query.stride(0), query.stride(1), query.stride(2), _1{});
   params.k_ptr = key.const_data_ptr();
-  params.k_stride = make_stride(key.stride(0), key.stride(1), key.stride(2));
+  params.k_stride =
+      make_stride(key.stride(0), key.stride(1), key.stride(2), _1{});
   params.v_ptr = value.const_data_ptr();
   params.v_stride =
-      make_stride(value.stride(0), value.stride(1), value.stride(2));
+      make_stride(value.stride(0), value.stride(1), value.stride(2), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
   params.alibi_slopes_ptr =
       alibi ? alibi_slopes.value().const_data_ptr<float>() : nullptr;
   params.batch_size = batch_size;
9 changes: 5 additions & 4 deletions src/kernels/attention/mha_sm80_pagedkv_bench.cu
@@ -92,13 +92,14 @@ void mha_bench_sm80(nvbench::state& state) {
   // construct attention params
   MHAPagedKVParams params;
   params.q_ptr = query.const_data_ptr();
-  params.q_stride = make_stride(query.stride(0), query.stride(1));
+  params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
   params.k_ptr = key_cache.const_data_ptr();
-  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
+  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
   params.v_ptr = value_cache.const_data_ptr();
-  params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
+  params.v_stride =
+      make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
   params.alibi_slopes_ptr =
       alibi ? alibi_slopes.value().const_data_ptr<float>() : nullptr;
   params.batch_size = batch_size;
26 changes: 9 additions & 17 deletions src/kernels/attention/mha_tile.h
@@ -44,18 +44,14 @@ struct MHATile<MHAParams> {
     auto q = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<1>(params_.q_stride), get<2>(params_.q_stride)),
-            _1{}),
+        make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)),
         packed_idx_to_coord);
 
     const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
     auto o = make_gather_tensor(
         make_gmem_ptr((Element*)params_.o_ptr + o_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<1>(params_.o_stride), get<2>(params_.o_stride)),
-            _1{}),
+        make_stride(select<1, 2>(params_.o_stride), get<3>(params_.o_stride)),
         packed_idx_to_coord);
     return make_tuple(q, o);
   }
@@ -72,12 +68,12 @@ struct MHATile<MHAParams> {
     auto k =
         make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
                     make_shape(params_.kv_len, params_.head_dim),
-                    make_stride(get<1>(params_.k_stride), _1{}));
+                    select<1, 3>(params_.k_stride));
     // v[batch_idx, :, kv_head_idx, :]
     auto v =
         make_tensor(make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
                     make_shape(params_.kv_len, params_.head_dim),
-                    make_stride(get<1>(params_.v_stride), _1{}));
+                    select<1, 3>(params_.v_stride));
     return make_tuple(k, v);
   }
 };
@@ -112,18 +108,14 @@ struct MHATile<MHAPagedKVParams> {
     auto q = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<0>(params_.q_stride), get<1>(params_.q_stride)),
-            _1{}),
+        make_stride(select<0, 1>(params_.q_stride), get<2>(params_.q_stride)),
         packed_idx_to_coord);
 
     const auto o_offset = begin * get<0>(params_.o_stride);
     auto o = make_gather_tensor(
         make_gmem_ptr((Element*)params_.o_ptr + o_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<0>(params_.o_stride), get<1>(params_.o_stride)),
-            _1{}),
+        make_stride(select<0, 1>(params_.o_stride), get<2>(params_.o_stride)),
         packed_idx_to_coord);
     return make_tuple(q, o);
   }
@@ -148,18 +140,18 @@ struct MHATile<MHAPagedKVParams> {
     auto k = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
         make_shape(kv_len, params_.head_dim),
-        make_stride(get<0>(params_.k_stride), _1{}),
+        select<0, 2>(params_.k_stride),
         idx_to_slot);
 
     // v[:, kv_head_idx, :]
     const auto v_offset = kv_head_idx_ * get<1>(params_.v_stride);
     auto v = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
         make_shape(kv_len, params_.head_dim),
-        make_stride(get<0>(params_.v_stride), _1{}),
+        select<0, 2>(params_.v_stride),
         idx_to_slot);
     return make_tuple(k, v);
   }
 };
 
-} // namespace llm
\ No newline at end of file
+} // namespace llm
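Note: with the static `_1` stored in the params, the tile builders no longer rebuild strides element by element; `select<I...>` slices the stride tuple and the compile-time `_1` rides along. A hedged sketch of the equivalence, with hypothetical stride values (not part of the diff):

// sketch.cu -- hypothetical, for illustration only
#include <cute/tensor.hpp>

#include <cstdint>
#include <type_traits>

using namespace cute;

int main() {
  // A rank-4 (batch, seq, head, dim) stride, as in MHAParams.
  auto q_stride = make_stride(int64_t{65536}, int64_t{512}, int64_t{64}, _1{});

  // Old pattern: pick elements with get<> and re-append _1{} manually.
  auto old_style =
      make_stride(make_stride(get<1>(q_stride), get<2>(q_stride)), _1{});
  // New pattern: slice the tuple; the static _1 at index 3 is preserved.
  auto new_style = make_stride(select<1, 2>(q_stride), get<3>(q_stride));

  static_assert(std::is_same_v<decltype(old_style), decltype(new_style)>,
                "both spell ((seq_stride, head_stride), _1)");
  return 0;
}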
12 changes: 6 additions & 6 deletions src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
@@ -84,16 +84,16 @@ torch::Tensor mla_pagedkv_sm80(
   // construct attention params
   MLAPagedKVParams params;
   params.q_ptr = q.const_data_ptr();
-  params.q_stride = make_stride(q.stride(0), q.stride(1));
+  params.q_stride = make_stride(q.stride(0), q.stride(1), _1{});
   params.kv_ptr = kv_cache.const_data_ptr();
-  params.kv_stride = make_stride(kv_cache.stride(0));
+  params.kv_stride = make_stride(kv_cache.stride(0), _1{});
   params.q_rope_ptr = q_rope.const_data_ptr();
-  params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1));
+  params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1), _1{});
   params.k_rope_ptr = k_rope_cache.const_data_ptr();
-  params.k_rope_stride = make_stride(k_rope_cache.stride(0));
+  params.k_rope_stride = make_stride(k_rope_cache.stride(0), _1{});
 
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
 
   params.batch_size = batch_size;
   params.max_q_len = max_q_len;
@@ -280,4 +280,4 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(64) // rope_head_dim
     ));
 
-} // namespace llm
\ No newline at end of file
+} // namespace llm
13 changes: 7 additions & 6 deletions src/kernels/attention/mla_kernel_sm80_test.cu
@@ -83,18 +83,19 @@ torch::Tensor mla_sm80(
   // construct attention params
   MLAParams params;
   params.q_ptr = q.const_data_ptr();
-  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2));
+  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2), _1{});
   params.kv_ptr = kv.const_data_ptr();
-  params.kv_stride = make_stride(kv.stride(0), kv.stride(1));
+  params.kv_stride = make_stride(kv.stride(0), kv.stride(1), _1{});
 
   params.q_rope_ptr = q_rope.const_data_ptr();
   params.q_rope_stride =
-      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2));
+      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2), _1{});
   params.k_rope_ptr = k_rope.const_data_ptr();
-  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1));
+  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1), _1{});
 
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
 
   params.batch_size = batch_size;
   params.max_q_len = q_len;
@@ -193,4 +194,4 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(64) // rope_head_dim
     ));
 
-} // namespace llm
\ No newline at end of file
+} // namespace llm
10 changes: 5 additions & 5 deletions src/kernels/attention/mla_params.h
@@ -64,9 +64,9 @@ struct MLAParamsCommon {
 
 struct MLAParams : public MLAParamsCommon {
   // Q/O: (batch, seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t, int64_t /*,_1*/>;
+  using Stride = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
   // KV: (batch, seq, dim): last dimension is contiguous
-  using KV_Stride = cute::Stride<int64_t, int64_t /*,_1*/>;
+  using KV_Stride = cute::Stride<int64_t, int64_t, cute::_1>;
 
   Stride q_stride;
   Stride q_rope_stride;
@@ -84,9 +84,9 @@ struct MLAParams : public MLAParamsCommon {
 // paged KV cache + variable length sequence
 struct MLAPagedKVParams : public MLAParamsCommon {
   // Q/O: (seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t /*,_1*/>;
+  using Stride = cute::Stride<int64_t, int64_t, cute::_1>;
   // KV: (seq, dim): last dimension is contiguous
-  using KV_Stride = cute::Stride<int64_t /*,_1*/>;
+  using KV_Stride = cute::Stride<int64_t, cute::_1>;
 
   Stride q_stride;
   Stride q_rope_stride;
@@ -108,4 +108,4 @@ struct MLAPagedKVParams : public MLAParamsCommon {
   const int* __restrict__ block_cu_lens = nullptr;
 };
 
-} // namespace llm
\ No newline at end of file
+} // namespace llm
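Note: the MLA strides get the same treatment, down to the rank-1 paged case: `KV_Stride` for a (seq, dim) cache goes from `Stride<int64_t>` to `Stride<int64_t, cute::_1>`. A hedged sketch of building a view over such a cache, with made-up sizes (not part of the diff):

// sketch.cu -- hypothetical, for illustration only
#include <cute/tensor.hpp>

#include <cstdint>
#include <vector>

using namespace cute;

int main() {
  // Mirrors MLAPagedKVParams::KV_Stride: (seq, dim), contiguous dim.
  using KV_Stride = cute::Stride<int64_t, cute::_1>;

  int64_t seq = 16, dim = 512;  // sizes are made up
  std::vector<float> kv_cache(seq * dim);

  KV_Stride kv_stride = make_stride(dim, _1{});

  // A (seq, dim) global-memory view; because the last stride is the
  // compile-time _1, contiguity of dim is visible to the type system.
  auto kv = make_tensor(make_gmem_ptr(kv_cache.data()),
                        make_shape(seq, dim),
                        kv_stride);
  (void)kv;
  return 0;
}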
11 changes: 6 additions & 5 deletions src/kernels/attention/mla_sm80_bench.cu
@@ -64,18 +64,19 @@ void mla_bench_sm80(nvbench::state& state) {
   // construct attention params
   MLAParams params;
   params.q_ptr = q.const_data_ptr();
-  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2));
+  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2), _1{});
   params.kv_ptr = kv.const_data_ptr();
-  params.kv_stride = make_stride(kv.stride(0), kv.stride(1));
+  params.kv_stride = make_stride(kv.stride(0), kv.stride(1), _1{});
 
   params.q_rope_ptr = q_rope.const_data_ptr();
   params.q_rope_stride =
-      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2));
+      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2), _1{});
   params.k_rope_ptr = k_rope.const_data_ptr();
-  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1));
+  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1), _1{});
 
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
 
   params.batch_size = batch_size;
   params.max_q_len = q_len;