diff --git a/src/kernels/attention/attn_api.cpp b/src/kernels/attention/attn_api.cpp
index bf78d26c..b5501af6 100644
--- a/src/kernels/attention/attn_api.cpp
+++ b/src/kernels/attention/attn_api.cpp
@@ -34,13 +34,14 @@ void paged_kv_varlen_mha(
   // construct attention params
   MHAPagedKVParams params;
   params.q_ptr = query.const_data_ptr();
-  params.q_stride = make_stride(query.stride(0), query.stride(1));
+  params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
   params.k_ptr = key_cache.const_data_ptr();
-  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
+  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
   params.v_ptr = value_cache.const_data_ptr();
-  params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
+  params.v_stride =
+      make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
   params.alibi_slopes_ptr = alibi_slopes.has_value()
                                 ? alibi_slopes.value().const_data_ptr()
                                 : nullptr;
diff --git a/src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu b/src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu
index 3e45bc15..eca9dc9e 100644
--- a/src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu
+++ b/src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu
@@ -48,13 +48,14 @@ torch::Tensor mha_pagedkv_sm80(
   // construct attention params
   MHAPagedKVParams params;
   params.q_ptr = query.const_data_ptr();
-  params.q_stride = make_stride(query.stride(0), query.stride(1));
+  params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
   params.k_ptr = key_cache.const_data_ptr();
-  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
+  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
   params.v_ptr = value_cache.const_data_ptr();
-  params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
+  params.v_stride =
+      make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
   params.alibi_slopes_ptr = alibi_slopes.has_value()
                                 ? alibi_slopes.value().const_data_ptr()
                                 : nullptr;
@@ -243,4 +244,4 @@ INSTANTIATE_TEST_SUITE_P(
                        ::testing::Values(-1, 0, 10)  // sliding window
                        ));
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mha_kernel_sm80_test.cu b/src/kernels/attention/mha_kernel_sm80_test.cu
index 90630fef..ac154211 100644
--- a/src/kernels/attention/mha_kernel_sm80_test.cu
+++ b/src/kernels/attention/mha_kernel_sm80_test.cu
@@ -60,14 +60,16 @@ torch::Tensor mha_sm80(
   MHAParams params;
   params.q_ptr = query.const_data_ptr();
   params.q_stride =
-      make_stride(query.stride(0), query.stride(1), query.stride(2));
+      make_stride(query.stride(0), query.stride(1), query.stride(2), _1{});
   params.k_ptr = key.const_data_ptr();
-  params.k_stride = make_stride(key.stride(0), key.stride(1), key.stride(2));
+  params.k_stride =
+      make_stride(key.stride(0), key.stride(1), key.stride(2), _1{});
   params.v_ptr = value.const_data_ptr();
   params.v_stride =
-      make_stride(value.stride(0), value.stride(1), value.stride(2));
+      make_stride(value.stride(0), value.stride(1), value.stride(2), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
   params.alibi_slopes_ptr = alibi_slopes.has_value()
                                 ? alibi_slopes.value().const_data_ptr()
                                 : nullptr;
@@ -168,4 +170,4 @@ INSTANTIATE_TEST_SUITE_P(
                        ::testing::Values(-1, 0, 10)  // sliding window
                        ));
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mha_params.h b/src/kernels/attention/mha_params.h
index 6f55eab3..a5b86dae 100644
--- a/src/kernels/attention/mha_params.h
+++ b/src/kernels/attention/mha_params.h
@@ -82,7 +82,7 @@ struct MHAParamsCommon {
 
 struct MHAParams : public MHAParamsCommon {
   // (batch, seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t, int64_t>;
+  using Stride = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
 
   Stride q_stride;
   Stride k_stride;
@@ -97,7 +97,7 @@ struct MHAParams : public MHAParamsCommon {
 // paged KV cache + variable length sequence
 struct MHAPagedKVParams : public MHAParamsCommon {
   // (seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t>;
+  using Stride = cute::Stride<int64_t, int64_t, cute::_1>;
 
   Stride q_stride;
   Stride k_stride;
@@ -116,4 +116,4 @@ struct MHAPagedKVParams : public MHAParamsCommon {
   const int* __restrict__ block_cu_lens = nullptr;
 };
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mha_sm80_bench.cu b/src/kernels/attention/mha_sm80_bench.cu
index be41448c..a279453a 100644
--- a/src/kernels/attention/mha_sm80_bench.cu
+++ b/src/kernels/attention/mha_sm80_bench.cu
@@ -48,14 +48,16 @@ void mha_bench_sm80(nvbench::state& state) {
   MHAParams params;
   params.q_ptr = query.const_data_ptr();
   params.q_stride =
-      make_stride(query.stride(0), query.stride(1), query.stride(2));
+      make_stride(query.stride(0), query.stride(1), query.stride(2), _1{});
   params.k_ptr = key.const_data_ptr();
-  params.k_stride = make_stride(key.stride(0), key.stride(1), key.stride(2));
+  params.k_stride =
+      make_stride(key.stride(0), key.stride(1), key.stride(2), _1{});
   params.v_ptr = value.const_data_ptr();
   params.v_stride =
-      make_stride(value.stride(0), value.stride(1), value.stride(2));
+      make_stride(value.stride(0), value.stride(1), value.stride(2), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
   params.alibi_slopes_ptr =
       alibi ? alibi_slopes.value().const_data_ptr() : nullptr;
   params.batch_size = batch_size;
diff --git a/src/kernels/attention/mha_sm80_pagedkv_bench.cu b/src/kernels/attention/mha_sm80_pagedkv_bench.cu
index 66ebc222..08891818 100644
--- a/src/kernels/attention/mha_sm80_pagedkv_bench.cu
+++ b/src/kernels/attention/mha_sm80_pagedkv_bench.cu
@@ -92,13 +92,14 @@ void mha_bench_sm80(nvbench::state& state) {
   // construct attention params
   MHAPagedKVParams params;
   params.q_ptr = query.const_data_ptr();
-  params.q_stride = make_stride(query.stride(0), query.stride(1));
+  params.q_stride = make_stride(query.stride(0), query.stride(1), _1{});
   params.k_ptr = key_cache.const_data_ptr();
-  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1));
+  params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1), _1{});
   params.v_ptr = value_cache.const_data_ptr();
-  params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1));
+  params.v_stride =
+      make_stride(value_cache.stride(0), value_cache.stride(1), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
   params.alibi_slopes_ptr =
       alibi ? alibi_slopes.value().const_data_ptr() : nullptr;
   params.batch_size = batch_size;
diff --git a/src/kernels/attention/mha_tile.h b/src/kernels/attention/mha_tile.h
index 8e04f927..dc535f93 100644
--- a/src/kernels/attention/mha_tile.h
+++ b/src/kernels/attention/mha_tile.h
@@ -44,18 +44,14 @@ struct MHATile {
     auto q = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<1>(params_.q_stride), get<2>(params_.q_stride)),
-            _1{}),
+        make_stride(select<1, 2>(params_.q_stride), get<3>(params_.q_stride)),
         packed_idx_to_coord);
 
     const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
     auto o = make_gather_tensor(
         make_gmem_ptr((Element*)params_.o_ptr + o_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<1>(params_.o_stride), get<2>(params_.o_stride)),
-            _1{}),
+        make_stride(select<1, 2>(params_.o_stride), get<3>(params_.o_stride)),
         packed_idx_to_coord);
     return make_tuple(q, o);
   }
@@ -72,12 +68,12 @@ struct MHATile {
     auto k =
         make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
                     make_shape(params_.kv_len, params_.head_dim),
-                    make_stride(get<1>(params_.k_stride), _1{}));
+                    select<1, 3>(params_.k_stride));
     // v[batch_idx, :, kv_head_idx, :]
     auto v =
         make_tensor(make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
                     make_shape(params_.kv_len, params_.head_dim),
-                    make_stride(get<1>(params_.v_stride), _1{}));
+                    select<1, 3>(params_.v_stride));
     return make_tuple(k, v);
   }
 };
@@ -112,18 +108,14 @@ struct MHATile {
     auto q = make_gather_tensor(
        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<0>(params_.q_stride), get<1>(params_.q_stride)),
-            _1{}),
+        make_stride(select<0, 1>(params_.q_stride), get<2>(params_.q_stride)),
         packed_idx_to_coord);
 
     const auto o_offset = begin * get<0>(params_.o_stride);
     auto o = make_gather_tensor(
         make_gmem_ptr((Element*)params_.o_ptr + o_offset),
         make_shape(packed_len, params_.head_dim),
-        make_stride(
-            make_stride(get<0>(params_.o_stride), get<1>(params_.o_stride)),
-            _1{}),
+        make_stride(select<0, 1>(params_.o_stride), get<2>(params_.o_stride)),
         packed_idx_to_coord);
     return make_tuple(q, o);
   }
@@ -148,7 +140,7 @@ struct MHATile {
     auto k = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.k_ptr + k_offset),
         make_shape(kv_len, params_.head_dim),
-        make_stride(get<0>(params_.k_stride), _1{}),
+        select<0, 2>(params_.k_stride),
         idx_to_slot);
 
     // v[:, kv_head_idx, :]
@@ -156,10 +148,10 @@ struct MHATile {
     auto v = make_gather_tensor(
         make_gmem_ptr((const Element*)params_.v_ptr + v_offset),
         make_shape(kv_len, params_.head_dim),
-        make_stride(get<0>(params_.v_stride), _1{}),
+        select<0, 2>(params_.v_stride),
         idx_to_slot);
     return make_tuple(k, v);
   }
 };
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu b/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
index 0910bf0b..c64e928e 100644
--- a/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
+++ b/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
@@ -84,16 +84,16 @@ torch::Tensor mla_pagedkv_sm80(
   // construct attention params
   MLAPagedKVParams params;
   params.q_ptr = q.const_data_ptr();
-  params.q_stride = make_stride(q.stride(0), q.stride(1));
+  params.q_stride = make_stride(q.stride(0), q.stride(1), _1{});
   params.kv_ptr = kv_cache.const_data_ptr();
-  params.kv_stride = make_stride(kv_cache.stride(0));
+  params.kv_stride = make_stride(kv_cache.stride(0), _1{});
 
   params.q_rope_ptr = q_rope.const_data_ptr();
-  params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1));
+  params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1), _1{});
   params.k_rope_ptr = k_rope_cache.const_data_ptr();
-  params.k_rope_stride = make_stride(k_rope_cache.stride(0));
+  params.k_rope_stride = make_stride(k_rope_cache.stride(0), _1{});
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1));
+  params.o_stride = make_stride(out.stride(0), out.stride(1), _1{});
 
   params.batch_size = batch_size;
   params.max_q_len = max_q_len;
@@ -280,4 +280,4 @@ INSTANTIATE_TEST_SUITE_P(
                        ::testing::Values(64)  // rope_head_dim
                        ));
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu
index 2b6c4098..ad72abf5 100644
--- a/src/kernels/attention/mla_kernel_sm80_test.cu
+++ b/src/kernels/attention/mla_kernel_sm80_test.cu
@@ -83,18 +83,19 @@ torch::Tensor mla_sm80(
   // construct attention params
   MLAParams params;
   params.q_ptr = q.const_data_ptr();
-  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2));
+  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2), _1{});
   params.kv_ptr = kv.const_data_ptr();
-  params.kv_stride = make_stride(kv.stride(0), kv.stride(1));
+  params.kv_stride = make_stride(kv.stride(0), kv.stride(1), _1{});
 
   params.q_rope_ptr = q_rope.const_data_ptr();
   params.q_rope_stride =
-      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2));
+      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2), _1{});
   params.k_rope_ptr = k_rope.const_data_ptr();
-  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1));
+  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1), _1{});
 
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
 
   params.batch_size = batch_size;
   params.max_q_len = q_len;
@@ -193,4 +194,4 @@ INSTANTIATE_TEST_SUITE_P(
                        ::testing::Values(64)  // rope_head_dim
                        ));
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mla_params.h b/src/kernels/attention/mla_params.h
index d77b50c7..fdb0113a 100644
--- a/src/kernels/attention/mla_params.h
+++ b/src/kernels/attention/mla_params.h
@@ -64,9 +64,9 @@ struct MLAParamsCommon {
 
 struct MLAParams : public MLAParamsCommon {
   // Q/O: (batch, seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t, int64_t>;
+  using Stride = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
   // KV: (batch, seq, dim): last dimension is contiguous
-  using KV_Stride = cute::Stride<int64_t, int64_t>;
+  using KV_Stride = cute::Stride<int64_t, int64_t, cute::_1>;
 
   Stride q_stride;
   Stride q_rope_stride;
@@ -84,9 +84,9 @@ struct MLAParams : public MLAParamsCommon {
 // paged KV cache + variable length sequence
 struct MLAPagedKVParams : public MLAParamsCommon {
   // Q/O: (seq, head, dim): last dimension is contiguous
-  using Stride = cute::Stride<int64_t, int64_t>;
+  using Stride = cute::Stride<int64_t, int64_t, cute::_1>;
   // KV: (seq, dim): last dimension is contiguous
-  using KV_Stride = cute::Stride<int64_t>;
+  using KV_Stride = cute::Stride<int64_t, cute::_1>;
 
   Stride q_stride;
   Stride q_rope_stride;
@@ -108,4 +108,4 @@ struct MLAPagedKVParams : public MLAParamsCommon {
   const int* __restrict__ block_cu_lens = nullptr;
 };
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu
index 030bb52a..29065c5c 100644
--- a/src/kernels/attention/mla_sm80_bench.cu
+++ b/src/kernels/attention/mla_sm80_bench.cu
@@ -64,18 +64,19 @@ void mla_bench_sm80(nvbench::state& state) {
   // construct attention params
   MLAParams params;
   params.q_ptr = q.const_data_ptr();
-  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2));
+  params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2), _1{});
   params.kv_ptr = kv.const_data_ptr();
-  params.kv_stride = make_stride(kv.stride(0), kv.stride(1));
+  params.kv_stride = make_stride(kv.stride(0), kv.stride(1), _1{});
 
   params.q_rope_ptr = q_rope.const_data_ptr();
   params.q_rope_stride =
-      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2));
+      make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2), _1{});
   params.k_rope_ptr = k_rope.const_data_ptr();
-  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1));
+  params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1), _1{});
 
   params.o_ptr = out.mutable_data_ptr();
-  params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2));
+  params.o_stride =
+      make_stride(out.stride(0), out.stride(1), out.stride(2), _1{});
 
   params.batch_size = batch_size;
   params.max_q_len = q_len;
diff --git a/src/kernels/attention/mla_tile.h b/src/kernels/attention/mla_tile.h
index 18ef1bc8..9b139f62 100644
--- a/src/kernels/attention/mla_tile.h
+++ b/src/kernels/attention/mla_tile.h
@@ -32,20 +32,20 @@ struct MLATile {
     auto q =
         make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
                     make_shape(q_packed_len, params_.head_dim),
-                    make_stride(get<2>(params_.q_stride), _1{}));
+                    select<2, 3>(params_.q_stride));
 
     // (batch, seq, head, rope_head_dim)
     const auto q_rope_offset = batch_idx_ * get<0>(params_.q_rope_stride);
    auto q_rope = make_tensor(
         make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset),
         make_shape(q_packed_len, params_.rope_head_dim),
-        make_stride(get<2>(params_.q_rope_stride), _1{}));
+        select<2, 3>(params_.q_rope_stride));
 
     // (batch, seq, head, dim)
     const auto o_offset = batch_idx_ * get<0>(params_.o_stride);
     auto o =
         make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
                     make_shape(q_packed_len, params_.head_dim),
-                    make_stride(get<2>(params_.o_stride), _1{}));
+                    select<2, 3>(params_.o_stride));
     return make_tuple(q, q_rope, o);
   }
@@ -55,18 +55,18 @@ struct MLATile {
   CUTE_HOST_DEVICE auto get_kv_tile() const {
     // (batch, seq, dim)
     const auto kv_offset = batch_idx_ * get<0>(params_.kv_stride);
-    // k[batch_idx, :, kv_head_idx, :]
+    // k[batch_idx, :, :]
     auto kv =
         make_tensor(make_gmem_ptr((const Element*)params_.kv_ptr + kv_offset),
                     make_shape(params_.kv_len, params_.head_dim),
-                    make_stride(get<1>(params_.kv_stride), _1{}));
+                    select<1, 2>(params_.kv_stride));
 
     // (batch, seq, rope_head_dim)
     const auto k_rope_offset = batch_idx_ * get<0>(params_.k_rope_stride);
     auto k_rope = make_tensor(
         make_gmem_ptr((const Element*)params_.k_rope_ptr + k_rope_offset),
         make_shape(params_.kv_len, params_.rope_head_dim),
-        make_stride(get<1>(params_.k_rope_stride), _1{}));
+        select<1, 2>(params_.k_rope_stride));
     return make_tuple(kv, k_rope);
   }
 };
@@ -95,20 +95,20 @@ struct MLATile {
     auto q =
         make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
                     make_shape(q_packed_len, params_.head_dim),
-                    make_stride(get<1>(params_.q_stride), _1{}));
+                    select<1, 2>(params_.q_stride));
 
     // (seq, head, rope_head_dim)
     const auto q_rope_offset = begin * get<0>(params_.q_rope_stride);
     auto q_rope = make_tensor(
         make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset),
         make_shape(q_packed_len, params_.rope_head_dim),
-        make_stride(get<1>(params_.q_rope_stride), _1{}));
+        select<1, 2>(params_.q_rope_stride));
 
     // (seq, head, dim)
     const auto o_offset = begin * get<0>(params_.o_stride);
     auto o =
         make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
                     make_shape(q_packed_len, params_.head_dim),
-                    make_stride(get<1>(params_.o_stride), _1{}));
+                    select<1, 2>(params_.o_stride));
     return make_tuple(q, q_rope, o);
   }
@@ -131,17 +131,17 @@ struct MLATile {
     // kv: (seq, dim)
     auto kv = make_gather_tensor(make_gmem_ptr((const Element*)params_.kv_ptr),
                                  make_shape(kv_len, params_.head_dim),
-                                 make_stride(get<0>(params_.kv_stride), _1{}),
+                                 params_.kv_stride,
                                  idx_to_slot);
 
     // k_rope: (seq, rope_head_dim)
     auto k_rope =
         make_gather_tensor(make_gmem_ptr((const Element*)params_.k_rope_ptr),
                            make_shape(kv_len, params_.rope_head_dim),
-                           make_stride(get<0>(params_.k_rope_stride), _1{}),
+                           params_.k_rope_stride,
                            idx_to_slot);
     return make_tuple(kv, k_rope);
   }
 };
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh
index 6603e1ce..091a037f 100644
--- a/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh
+++ b/src/kernels/gemm/grouped_gemm_kernel_sm80.cuh
@@ -126,9 +126,9 @@ struct GEMMSharedStorageSM80 {
 };
 
 struct GEMMParams {
-  using AStride = Stride<int64_t>;
-  using BStride = Stride<int64_t, int64_t>;
-  using CStride = Stride<int64_t>;
+  using AStride = Stride<int64_t, _1>;
+  using BStride = Stride<int64_t, int64_t, _1>;
+  using CStride = Stride<int64_t, _1>;
 
   // A: (m, k)
   const void* __restrict__ a_ptr = nullptr;
@@ -206,14 +206,14 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
   // A: (M, K), k-major
   auto A = make_gather_tensor(make_gmem_ptr((const DType*)params.a_ptr),
                               make_shape(M, K),
-                              make_stride(get<0>(params.a_stride), _1{}),
+                              params.a_stride,
                               idx_to_t_idx);
 
   // B: (N, K), k-major
   const auto b_offset = expert_id * get<0>(params.b_stride);
   auto B = make_tensor(make_gmem_ptr((const DType*)params.b_ptr + b_offset),
                        make_shape(N, K),
-                       make_stride(get<1>(params.b_stride), _1{}));
+                       select<1, 2>(params.b_stride));
 
   // C: (M, N), n-major
   auto idx_to_f_idx = [sorted_token_idxes](int idx) {
@@ -221,7 +221,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
   };
   auto C = make_gather_tensor(make_gmem_ptr((DType*)params.c_ptr),
                               make_shape(M, N),
-                              make_stride(get<0>(params.c_stride), _1{}),
+                              params.c_stride,
                               idx_to_f_idx);
 
   auto max_coord_mk = make_coord(M, K);
diff --git a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu
index 34701156..2fad39a5 100644
--- a/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu
+++ b/src/kernels/gemm/grouped_gemm_kernel_sm80_test.cu
@@ -82,11 +82,11 @@ torch::Tensor grouped_gemm_sm80(const torch::Tensor& a,  // (m, k)
   // construct params
   GEMMParams params;
   params.a_ptr = a.const_data_ptr();
-  params.a_stride = make_stride(a.stride(0));
+  params.a_stride = make_stride(a.stride(0), _1{});
   params.b_ptr = w.const_data_ptr();
-  params.b_stride = make_stride(w.stride(0), w.stride(1));
+  params.b_stride = make_stride(w.stride(0), w.stride(1), _1{});
   params.c_ptr = out.mutable_data_ptr();
-  params.c_stride = make_stride(out.stride(0));
+  params.c_stride = make_stride(out.stride(0), _1{});
 
   params.sorted_token_idxes_ptr = sorted_token_idex.const_data_ptr();
   params.expert_ids_ptr = expert_ids.const_data_ptr();
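
Note (not part of the patch): every hunk above follows the same pattern. The stride tuples in the params structs now carry a compile-time _1{} for the contiguous last dimension, so the tile and kernel code can slice the tuple with select<>/get<> instead of re-assembling a stride with a hard-coded _1{}. Below is a minimal, hypothetical sketch of that CuTe usage: the tensor extents, the reuse of one stride tuple for both views, and the main() driver are invented for illustration; only make_stride, select, and get mirror what the patch does.

// Minimal CuTe sketch of the stride layout introduced by this patch.
// Hypothetical shapes; compile against the CUTLASS/CuTe headers.
#include <cute/tensor.hpp>

#include <cstdint>
#include <cstdio>

using namespace cute;

int main() {
  // Assumed (batch, seq, head, dim) strides for a contiguous [2, 8, 4, 64]
  // query tensor; the last element is a static _1, as in MHAParams::Stride.
  auto q_stride = make_stride(int64_t{8 * 4 * 64},  // batch
                              int64_t{4 * 64},      // seq
                              int64_t{64},          // head
                              _1{});                // dim, compile-time 1

  // Per-batch q/o view stride ((seq, head), dim), as built in mha_tile.h:
  // select<1, 2> keeps the (seq, head) strides, get<3> forwards the static _1.
  auto qo_stride = make_stride(select<1, 2>(q_stride), get<3>(q_stride));

  // 2-D (seq, dim) view stride, as used for the k/v tiles.
  auto kv_stride = select<1, 3>(q_stride);

  print(qo_stride);  // prints something like ((256,64),_1)
  printf("\n");
  print(kv_stride);  // prints something like (256,_1)
  printf("\n");
  return 0;
}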