diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt
index 0cb4371e..80de7486 100644
--- a/src/kernels/attention/CMakeLists.txt
+++ b/src/kernels/attention/CMakeLists.txt
@@ -6,6 +6,8 @@ include(cc_test)
 cc_library(
   NAME
     attention.template
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   HDRS
     fast_math.h
     layout_convertor.h
@@ -41,13 +43,13 @@ file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/gensrc/*.cu")
 cc_library(
   NAME
     attention.kernels
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   HDRS
     attn_api.h
   SRCS
     attn_api.cpp
     ${GENERATED_SRC_FILES}
-  INCLUDES
-    ${CMAKE_CURRENT_SOURCE_DIR}
   DEPS
     :attention.template
     glog::glog
@@ -56,9 +58,11 @@ cc_library(
 cc_test(
   NAME
     mha_kernel_test
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
     # sm80_mha_test.cu
-    sm80_mha_pagedkv_test.cu
+    tests/sm80_mha_pagedkv_test.cu
   DEPS
     :attention.kernels
     absl::random_random
@@ -69,33 +73,39 @@ cc_test(
 cc_test(
   NAME
     sm120_fmha_kernel_test
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
-    sm120_fmha_test.cu
+    tests/sm120_fmha_test.cu
   DEPS
     :attention.template
-    absl::random_random
     :gtest_main
+    absl::random_random
     torch
 )
 cc_test(
   NAME
     mla_kernel_test
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
-    sm80_mla_test.cu
-    sm80_mla_pagedkv_test.cu
+    tests/sm80_mla_test.cu
+    tests/sm80_mla_pagedkv_test.cu
   DEPS
     :attention.kernels
-    absl::random_random
     :gtest_main
+    absl::random_random
     torch
 )
 cc_test(
   NAME
     attn_combine_kernel_test
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
-    attn_combine_kernel_test.cu
+    tests/attn_combine_kernel_test.cu
   DEPS
     :attention.template
     absl::random_random
@@ -103,43 +113,51 @@ cc_test(
     torch
 )
-nvbench_binary(
-  NAME
-    sm80_mha_bench
-  SRCS
-    sm80_mha_bench.cu
-  DEPS
-    :attention.template
-)
-
-nvbench_binary(
-  NAME
-    sm80_mha_pagedkv_bench
-  SRCS
-    sm80_mha_pagedkv_bench.cu
-  DEPS
-    absl::random_random
-    :attention.template
-)
-
-nvbench_binary(
-  NAME
-    sm80_mla_bench
-  SRCS
-    mla_sm80_bench.cu
-  DEPS
-    :attention.template
-)
-
 cc_test(
   NAME
     tma_test
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
-    sm120_tma_block_load_test.cu
+    tests/sm120_tma_block_load_test.cu
   DEPS
     :gtest_main
     :cutlass
     torch
 )
-add_subdirectory(tools)
+# nvbench_binary(
+#   NAME
+#     sm80_mha_bench
+#   INCLUDES
+#     ${CMAKE_CURRENT_SOURCE_DIR}
+#   SRCS
+#     sm80_mha_bench.cu
+#   DEPS
+#     :attention.template
+# )
+
+# nvbench_binary(
+#   NAME
+#     sm80_mha_pagedkv_bench
+#   INCLUDES
+#     ${CMAKE_CURRENT_SOURCE_DIR}
+#   SRCS
+#     sm80_mha_pagedkv_bench.cu
+#   DEPS
+#     absl::random_random
+#     :attention.template
+# )
+
+# nvbench_binary(
+#   NAME
+#     sm80_mla_bench
+#   INCLUDES
+#     ${CMAKE_CURRENT_SOURCE_DIR}
+#   SRCS
+#     mla_sm80_bench.cu
+#   DEPS
+#     :attention.template
+# )
+
+# add_subdirectory(tools)
diff --git a/src/kernels/attention/attn_api.cpp b/src/kernels/attention/attn_api.cpp
index 0853b77d..3f7a60fd 100644
--- a/src/kernels/attention/attn_api.cpp
+++ b/src/kernels/attention/attn_api.cpp
@@ -2,10 +2,11 @@
 #include
-#include "cute/layout.hpp"
+#include
+
+#include "common/static_dispatch.h"
+#include "device/sm80_mha_dispatch.cuh"
 #include "mha_params.h"
-#include "sm80_mha_dispatch.cuh"
-#include "static_dispatch.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/sm80_mha_bench.cu b/src/kernels/attention/bench/sm80_mha_bench.cu
similarity index 100%
rename from src/kernels/attention/sm80_mha_bench.cu
rename to src/kernels/attention/bench/sm80_mha_bench.cu
diff --git a/src/kernels/attention/sm80_mha_pagedkv_bench.cu b/src/kernels/attention/bench/sm80_mha_pagedkv_bench.cu
similarity index 100%
rename from src/kernels/attention/sm80_mha_pagedkv_bench.cu
rename to src/kernels/attention/bench/sm80_mha_pagedkv_bench.cu
diff --git a/src/kernels/attention/sm80_mla_bench.cu b/src/kernels/attention/bench/sm80_mla_bench.cu
similarity index 100%
rename from src/kernels/attention/sm80_mla_bench.cu
rename to src/kernels/attention/bench/sm80_mla_bench.cu
diff --git a/src/kernels/attention/sm120_collective_epilogue.cuh b/src/kernels/attention/collective/sm120_collective_epilogue.cuh
similarity index 98%
rename from src/kernels/attention/sm120_collective_epilogue.cuh
rename to src/kernels/attention/collective/sm120_collective_epilogue.cuh
index bfaf0726..9f88995d 100644
--- a/src/kernels/attention/sm120_collective_epilogue.cuh
+++ b/src/kernels/attention/collective/sm120_collective_epilogue.cuh
@@ -7,8 +7,8 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "safe_copy.h"
+#include "common/fast_cast.cuh"
+#include "common/safe_copy.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/sm120_collective_fmha_mainloop_ws.cuh b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
similarity index 99%
rename from src/kernels/attention/sm120_collective_fmha_mainloop_ws.cuh
rename to src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
index cc13cc8b..1e1027cc 100644
--- a/src/kernels/attention/sm120_collective_fmha_mainloop_ws.cuh
+++ b/src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh
@@ -9,9 +9,9 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "layout_convertor.h"
-#include "safe_copy.h"
+#include "common/fast_cast.cuh"
+#include "common/layout_convertor.h"
+#include "common/safe_copy.h"
 #include "sm120_collective_load_cpasync_ws.cuh"
 namespace llm {
diff --git a/src/kernels/attention/sm120_collective_load_cpasync_ws.cuh b/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh
similarity index 99%
rename from src/kernels/attention/sm120_collective_load_cpasync_ws.cuh
rename to src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh
index 1c3d944e..d4180ed8 100644
--- a/src/kernels/attention/sm120_collective_load_cpasync_ws.cuh
+++ b/src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh
@@ -9,7 +9,7 @@
 #include
 #include
-#include "safe_copy.h"
+#include "common/safe_copy.h"
 namespace llm {
diff --git a/src/kernels/attention/sm80_collective_epilogue.cuh b/src/kernels/attention/collective/sm80_collective_epilogue.cuh
similarity index 98%
rename from src/kernels/attention/sm80_collective_epilogue.cuh
rename to src/kernels/attention/collective/sm80_collective_epilogue.cuh
index 4840eddc..05329a3f 100644
--- a/src/kernels/attention/sm80_collective_epilogue.cuh
+++ b/src/kernels/attention/collective/sm80_collective_epilogue.cuh
@@ -7,8 +7,8 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "safe_copy.h"
+#include "common/fast_cast.cuh"
+#include "common/safe_copy.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/sm80_collective_mha.cuh b/src/kernels/attention/collective/sm80_collective_mha.cuh
similarity index 99%
rename from src/kernels/attention/sm80_collective_mha.cuh
rename to src/kernels/attention/collective/sm80_collective_mha.cuh
index a609e9e0..19e6a6be 100644
--- a/src/kernels/attention/sm80_collective_mha.cuh
+++ b/src/kernels/attention/collective/sm80_collective_mha.cuh
@@ -8,9 +8,9 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "layout_convertor.h"
-#include "safe_copy.h"
+#include "common/fast_cast.cuh"
+#include "common/layout_convertor.h"
+#include "common/safe_copy.h"
 namespace llm {
diff --git a/src/kernels/attention/sm80_collective_mla.cuh b/src/kernels/attention/collective/sm80_collective_mla.cuh
similarity index 99%
rename from src/kernels/attention/sm80_collective_mla.cuh
rename to src/kernels/attention/collective/sm80_collective_mla.cuh
index 55af9d33..5291c294 100644
--- a/src/kernels/attention/sm80_collective_mla.cuh
+++ b/src/kernels/attention/collective/sm80_collective_mla.cuh
@@ -8,10 +8,10 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "layout_convertor.h"
-#include "mask.h"
-#include "safe_copy.h"
+#include "common/fast_cast.cuh"
+#include "common/layout_convertor.h"
+#include "common/mask.h"
+#include "common/safe_copy.h"
 namespace llm {
diff --git a/src/kernels/attention/sm80_collective_mla_epilogue.cuh b/src/kernels/attention/collective/sm80_collective_mla_epilogue.cuh
similarity index 98%
rename from src/kernels/attention/sm80_collective_mla_epilogue.cuh
rename to src/kernels/attention/collective/sm80_collective_mla_epilogue.cuh
index c53cbba7..6bf6c2f6 100644
--- a/src/kernels/attention/sm80_collective_mla_epilogue.cuh
+++ b/src/kernels/attention/collective/sm80_collective_mla_epilogue.cuh
@@ -7,8 +7,8 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "safe_copy.h"
+#include "common/fast_cast.cuh"
+#include "common/safe_copy.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/gemm/fast_cast.cuh b/src/kernels/attention/common/fast_cast.cuh
similarity index 100%
rename from src/kernels/gemm/fast_cast.cuh
rename to src/kernels/attention/common/fast_cast.cuh
diff --git a/src/kernels/attention/fast_math.h b/src/kernels/attention/common/fast_math.h
similarity index 100%
rename from src/kernels/attention/fast_math.h
rename to src/kernels/attention/common/fast_math.h
diff --git a/src/kernels/attention/gather_tensor.hpp b/src/kernels/attention/common/gather_tensor.h
similarity index 98%
rename from src/kernels/attention/gather_tensor.hpp
rename to src/kernels/attention/common/gather_tensor.h
index 79ca581b..594fb8bc 100644
--- a/src/kernels/attention/gather_tensor.hpp
+++ b/src/kernels/attention/common/gather_tensor.h
@@ -2,9 +2,9 @@
 // https://github.com/NVIDIA/cutlass/blob/main/examples/common/gather_tensor.hpp
 #pragma once
-#include "cute/layout.hpp"
-#include "cute/layout_composed.hpp"
-#include "cute/tensor.hpp"
+#include
+#include
+#include
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/gather_tma_copy.h b/src/kernels/attention/common/gather_tma_copy.h
similarity index 100%
rename from src/kernels/attention/gather_tma_copy.h
rename to src/kernels/attention/common/gather_tma_copy.h
diff --git a/src/kernels/attention/gather_tma_tensor.hpp b/src/kernels/attention/common/gather_tma_tensor.h
similarity index 91%
rename from src/kernels/attention/gather_tma_tensor.hpp
rename to src/kernels/attention/common/gather_tma_tensor.h
index 3dc29ccc..ad508fe0 100644
--- a/src/kernels/attention/gather_tma_tensor.hpp
+++ b/src/kernels/attention/common/gather_tma_tensor.h
@@ -24,13 +24,11 @@ struct GatherArithmeticTupleIterator {
                                        const Transform& transform)
       : coord_(coord), transform_(transform) {}
-  CUTE_HOST_DEVICE constexpr auto coord() const {
+  CUTE_HOST_DEVICE constexpr auto operator*() const {
     // apply the transform to the coordinate
     return transform_(coord_);
   }
-  CUTE_HOST_DEVICE constexpr auto operator*() const { return coord(); }
-
   template
   CUTE_HOST_DEVICE constexpr auto operator+(const Coord& c) const {
     auto coord = coord_ + c;
@@ -67,7 +65,7 @@ template
 CUTE_HOST_DEVICE void print(
     const GatherArithmeticTupleIterator& iter) {
   printf("GatherArithTuple");
-  print(iter.coord());
+  print(*iter);
 }
 #if !defined(__CUDACC_RTC__)
@@ -75,7 +73,7 @@ template
 CUTE_HOST std::ostream& operator<<(
     std::ostream& os,
     const GatherArithmeticTupleIterator& iter) {
-  return os << "GatherArithTuple" << iter.coord();
+  return os << "GatherArithTuple" << *iter;
 }
 #endif
diff --git a/src/kernels/attention/layout_convertor.h b/src/kernels/attention/common/layout_convertor.h
similarity index 100%
rename from src/kernels/attention/layout_convertor.h
rename to src/kernels/attention/common/layout_convertor.h
diff --git a/src/kernels/attention/mask.h b/src/kernels/attention/common/mask.h
similarity index 100%
rename from src/kernels/attention/mask.h
rename to src/kernels/attention/common/mask.h
diff --git a/src/kernels/attention/online_softmax.cuh b/src/kernels/attention/common/online_softmax.cuh
similarity index 99%
rename from src/kernels/attention/online_softmax.cuh
rename to src/kernels/attention/common/online_softmax.cuh
index 5f4ec6b7..6aeac2d1 100644
--- a/src/kernels/attention/online_softmax.cuh
+++ b/src/kernels/attention/common/online_softmax.cuh
@@ -161,4 +161,4 @@ struct OnlineSoftmax {
   }
 };
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/attention/safe_copy.h b/src/kernels/attention/common/safe_copy.h
similarity index 99%
rename from src/kernels/attention/safe_copy.h
rename to src/kernels/attention/common/safe_copy.h
index 8312fe1c..14f6881c 100644
--- a/src/kernels/attention/safe_copy.h
+++ b/src/kernels/attention/common/safe_copy.h
@@ -1,11 +1,10 @@
 #pragma once
 #include
+#include
+#include
 #include
-#include "cute/config.hpp"
-#include "cute/layout.hpp"
-
 namespace cute {
 namespace detail {
diff --git a/src/kernels/attention/static_dispatch.h b/src/kernels/attention/common/static_dispatch.h
similarity index 99%
rename from src/kernels/attention/static_dispatch.h
rename to src/kernels/attention/common/static_dispatch.h
index e95f8ab3..621eeb4f 100644
--- a/src/kernels/attention/static_dispatch.h
+++ b/src/kernels/attention/common/static_dispatch.h
@@ -45,4 +45,4 @@ namespace llm {
   } \
   }()
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/attention/tile_scheduler.cuh b/src/kernels/attention/common/tile_scheduler.cuh
similarity index 100%
rename from src/kernels/attention/tile_scheduler.cuh
rename to src/kernels/attention/common/tile_scheduler.cuh
diff --git a/src/kernels/attention/sm120_fmha_dispatch.cuh b/src/kernels/attention/device/sm120_fmha_dispatch.cuh
similarity index 97%
rename from src/kernels/attention/sm120_fmha_dispatch.cuh
rename to src/kernels/attention/device/sm120_fmha_dispatch.cuh
index 5ca96aa5..56b79157 100644
--- a/src/kernels/attention/sm120_fmha_dispatch.cuh
+++ b/src/kernels/attention/device/sm120_fmha_dispatch.cuh
@@ -3,7 +3,7 @@
 #include
 #include
-#include "static_dispatch.h"
+#include "common/static_dispatch.h"
 namespace llm {
 // forward declaration
diff --git a/src/kernels/attention/sm120_fmha_launch.cuh b/src/kernels/attention/device/sm120_fmha_launch.cuh
similarity index 94%
rename from src/kernels/attention/sm120_fmha_launch.cuh
rename to src/kernels/attention/device/sm120_fmha_launch.cuh
index dff72f9b..4975f8f4 100644
--- a/src/kernels/attention/sm120_fmha_launch.cuh
+++ b/src/kernels/attention/device/sm120_fmha_launch.cuh
@@ -6,10 +6,10 @@
 #include
 #include
-#include "sm120_collective_epilogue.cuh"
-#include "sm120_collective_fmha_mainloop_ws.cuh"
-#include "sm120_kernel_fmha_ws.cuh"
-#include "tile_scheduler.cuh"
+#include "collective/sm120_collective_epilogue.cuh"
+#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
+#include "common/tile_scheduler.cuh"
+#include "kernel/sm120_kernel_fmha_ws.cuh"
 namespace llm {
diff --git a/src/kernels/attention/sm80_mha_dispatch.cuh b/src/kernels/attention/device/sm80_mha_dispatch.cuh
similarity index 97%
rename from src/kernels/attention/sm80_mha_dispatch.cuh
rename to src/kernels/attention/device/sm80_mha_dispatch.cuh
index 30136b7a..e64cde04 100644
--- a/src/kernels/attention/sm80_mha_dispatch.cuh
+++ b/src/kernels/attention/device/sm80_mha_dispatch.cuh
@@ -3,7 +3,7 @@
 #include
 #include
-#include "static_dispatch.h"
+#include "common/static_dispatch.h"
 namespace llm {
 // forward declaration
diff --git a/src/kernels/attention/sm80_mha_launch.cuh b/src/kernels/attention/device/sm80_mha_launch.cuh
similarity index 95%
rename from src/kernels/attention/sm80_mha_launch.cuh
rename to src/kernels/attention/device/sm80_mha_launch.cuh
index f26f71e8..89bffb63 100644
--- a/src/kernels/attention/sm80_mha_launch.cuh
+++ b/src/kernels/attention/device/sm80_mha_launch.cuh
@@ -6,10 +6,10 @@
 #include
 #include
-#include "sm80_collective_epilogue.cuh"
-#include "sm80_collective_mha.cuh"
-#include "sm80_kernel_mha.cuh"
-#include "tile_scheduler.cuh"
+#include "collective/sm80_collective_epilogue.cuh"
+#include "collective/sm80_collective_mha.cuh"
+#include "common/tile_scheduler.cuh"
+#include "kernel/sm80_kernel_mha.cuh"
 namespace llm {
diff --git a/src/kernels/attention/sm80_mla_dispatch.cuh b/src/kernels/attention/device/sm80_mla_dispatch.cuh
similarity index 98%
rename from src/kernels/attention/sm80_mla_dispatch.cuh
rename to src/kernels/attention/device/sm80_mla_dispatch.cuh
index fb4ebe35..82de3d8d 100644
--- a/src/kernels/attention/sm80_mla_dispatch.cuh
+++ b/src/kernels/attention/device/sm80_mla_dispatch.cuh
@@ -3,8 +3,8 @@
 #include
 #include
+#include "common/static_dispatch.h"
 #include "sm80_mla_launch.cuh"
-#include "static_dispatch.h"
 namespace llm {
diff --git a/src/kernels/attention/sm80_mla_launch.cuh b/src/kernels/attention/device/sm80_mla_launch.cuh
similarity index 95%
rename from src/kernels/attention/sm80_mla_launch.cuh
rename to src/kernels/attention/device/sm80_mla_launch.cuh
index 8dd8975c..33f0aad9 100644
--- a/src/kernels/attention/sm80_mla_launch.cuh
+++ b/src/kernels/attention/device/sm80_mla_launch.cuh
@@ -6,10 +6,10 @@
 #include
 #include
-#include "sm80_collective_mla.cuh"
-#include "sm80_collective_mla_epilogue.cuh"
-#include "sm80_kernel_mla.cuh"
-#include "tile_scheduler.cuh"
+#include "collective/sm80_collective_mla.cuh"
+#include "collective/sm80_collective_mla_epilogue.cuh"
+#include "common/tile_scheduler.cuh"
+#include "kernel/sm80_kernel_mla.cuh"
 namespace llm {
diff --git a/src/kernels/attention/generate_instantiation_cu.py b/src/kernels/attention/generate_instantiation_cu.py
index b9cd8cb4..14c03766 100755
--- a/src/kernels/attention/generate_instantiation_cu.py
+++ b/src/kernels/attention/generate_instantiation_cu.py
@@ -20,8 +20,8 @@ SM80_MHA_KERNEL_TEMPLATE = """
-#include "sm80_mha_launch.cuh" // IWYU pragma: export
-#include "mha_params.h" // IWYU pragma: export
+#include "device/sm80_mha_launch.cuh" // IWYU pragma: export
+#include "mha_params.h" // IWYU pragma: export
 namespace llm {{
@@ -39,8 +39,8 @@ """
 SM120_MHA_KERNEL_TEMPLATE = """
-#include "sm120_fmha_launch.cuh" // IWYU pragma: export
-#include "mha_params.h" // IWYU pragma: export
+#include "device/sm120_fmha_launch.cuh" // IWYU pragma: export
+#include "mha_params.h" // IWYU pragma: export
 namespace llm {{
@@ -58,8 +58,8 @@ """
 MLA_KERNEL_TEMPLATE = """
-#include "sm80_mla_launch.cuh" // IWYU pragma: export
-#include "mla_params.h" // IWYU pragma: export
+#include "device/sm80_mla_launch.cuh" // IWYU pragma: export
+#include "mla_params.h" // IWYU pragma: export
 namespace llm {{
diff --git a/src/kernels/attention/attn_combine_kernel.cuh b/src/kernels/attention/kernel/attn_combine_kernel.cuh
similarity index 99%
rename from src/kernels/attention/attn_combine_kernel.cuh
rename to src/kernels/attention/kernel/attn_combine_kernel.cuh
index 2ad749ed..973af9a2 100644
--- a/src/kernels/attention/attn_combine_kernel.cuh
+++ b/src/kernels/attention/kernel/attn_combine_kernel.cuh
@@ -6,8 +6,8 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "safe_copy.h"
+#include "common/fast_cast.cuh"
+#include "common/safe_copy.h"
 namespace llm {
diff --git a/src/kernels/attention/sm120_kernel_fmha_ws.cuh b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh
similarity index 99%
rename from src/kernels/attention/sm120_kernel_fmha_ws.cuh
rename to src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh
index 46254a75..4e6dfda2 100644
--- a/src/kernels/attention/sm120_kernel_fmha_ws.cuh
+++ b/src/kernels/attention/kernel/sm120_kernel_fmha_ws.cuh
@@ -8,11 +8,11 @@
 #include
 #include
-#include "gather_tensor.hpp"
-#include "layout_convertor.h"
-#include "mask.h"
+#include "common/gather_tensor.h"
+#include "common/layout_convertor.h"
+#include "common/mask.h"
+#include "common/online_softmax.cuh"
 #include "mha_params.h"
-#include "online_softmax.cuh"
 namespace llm {
diff --git a/src/kernels/attention/sm80_kernel_mha.cuh b/src/kernels/attention/kernel/sm80_kernel_mha.cuh
similarity index 98%
rename from src/kernels/attention/sm80_kernel_mha.cuh
rename to src/kernels/attention/kernel/sm80_kernel_mha.cuh
index f9e8f7af..f3bc5850 100644
--- a/src/kernels/attention/sm80_kernel_mha.cuh
+++ b/src/kernels/attention/kernel/sm80_kernel_mha.cuh
@@ -6,11 +6,11 @@
 #include
 #include
-#include "gather_tensor.hpp"
-#include "layout_convertor.h"
-#include "mask.h"
+#include "common/gather_tensor.h"
+#include "common/layout_convertor.h"
+#include "common/mask.h"
+#include "common/online_softmax.cuh"
 #include "mha_params.h"
-#include "online_softmax.cuh"
 namespace llm {
diff --git a/src/kernels/attention/sm80_kernel_mla.cuh b/src/kernels/attention/kernel/sm80_kernel_mla.cuh
similarity index 99%
rename from src/kernels/attention/sm80_kernel_mla.cuh
rename to src/kernels/attention/kernel/sm80_kernel_mla.cuh
index 8f65e58d..1f3a79cb 100644
--- a/src/kernels/attention/sm80_kernel_mla.cuh
+++ b/src/kernels/attention/kernel/sm80_kernel_mla.cuh
@@ -6,9 +6,9 @@
 #include
 #include
-#include "gather_tensor.hpp"
+#include "common/gather_tensor.h"
+#include "common/online_softmax.cuh"
 #include "mla_params.h"
-#include "online_softmax.cuh"
 namespace llm {
diff --git a/src/kernels/attention/mha_params.h b/src/kernels/attention/mha_params.h
index a5b86dae..9c68a81b 100644
--- a/src/kernels/attention/mha_params.h
+++ b/src/kernels/attention/mha_params.h
@@ -4,7 +4,7 @@
 #include
 #include
-#include "fast_math.h"
+#include "common/fast_math.h"
 namespace llm {
 // common params for attention kernels
diff --git a/src/kernels/attention/mla_params.h b/src/kernels/attention/mla_params.h
index fdb0113a..2128ff35 100644
--- a/src/kernels/attention/mla_params.h
+++ b/src/kernels/attention/mla_params.h
@@ -4,7 +4,7 @@
 #include
 #include
-#include "fast_math.h"
+#include "common/fast_math.h"
 namespace llm {
 // common params for attention kernels
diff --git a/src/kernels/attention/attn_combine_kernel_test.cu b/src/kernels/attention/tests/attn_combine_kernel_test.cu
similarity index 98%
rename from src/kernels/attention/attn_combine_kernel_test.cu
rename to src/kernels/attention/tests/attn_combine_kernel_test.cu
index 2d6cb642..445cf2ac 100644
--- a/src/kernels/attention/attn_combine_kernel_test.cu
+++ b/src/kernels/attention/tests/attn_combine_kernel_test.cu
@@ -4,8 +4,8 @@
 #include
-#include "attn_combine_kernel.cuh" // IWYU pragma: keep
-#include "static_dispatch.h"
+#include "common/static_dispatch.h"
+#include "kernel/attn_combine_kernel.cuh"
 namespace llm {
@@ -201,4 +201,4 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(32, 64, 80, 128, 192, 256) // head_dim
     ));
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/attention/mha_cpu.h b/src/kernels/attention/tests/mha_cpu.h
similarity index 98%
rename from src/kernels/attention/mha_cpu.h
rename to src/kernels/attention/tests/mha_cpu.h
index 2d0c120b..6cd42ce1 100644
--- a/src/kernels/attention/mha_cpu.h
+++ b/src/kernels/attention/tests/mha_cpu.h
@@ -4,11 +4,11 @@
 #include
 #include
+#include
+#include
 #include
 #include "common/range.h"
-#include "cute/layout.hpp"
-#include "cute/stride.hpp"
 namespace llm {
 // query/out: [q_seq_len, n_head, head_dim]
@@ -126,4 +126,4 @@ inline void mha(torch::Tensor query,
   }
 }
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/attention/mha_cpu_test.cpp b/src/kernels/attention/tests/mha_cpu_test.cpp
similarity index 99%
rename from src/kernels/attention/mha_cpu_test.cpp
rename to src/kernels/attention/tests/mha_cpu_test.cpp
index d569bfd7..4aa158ec 100644
--- a/src/kernels/attention/mha_cpu_test.cpp
+++ b/src/kernels/attention/tests/mha_cpu_test.cpp
@@ -85,4 +85,4 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(32) // head_dim
     ));
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/attention/mha_ref.h b/src/kernels/attention/tests/mha_ref.h
similarity index 99%
rename from src/kernels/attention/mha_ref.h
rename to src/kernels/attention/tests/mha_ref.h
index 865293b2..5000f4f7 100644
--- a/src/kernels/attention/mha_ref.h
+++ b/src/kernels/attention/tests/mha_ref.h
@@ -167,4 +167,4 @@ inline torch::Tensor mha_varlen_ref(
   }
   return torch::cat(out_list, /*dim=*/0);
 }
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/tests/mla_ref.h
similarity index 99%
rename from src/kernels/attention/mla_ref.h
rename to src/kernels/attention/tests/mla_ref.h
index 94124662..e89eb6ca 100644
--- a/src/kernels/attention/mla_ref.h
+++ b/src/kernels/attention/tests/mla_ref.h
@@ -119,4 +119,4 @@ inline torch::Tensor mla_varlen_ref(
   return torch::cat(out_list, /*dim=*/0);
 }
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/attention/sm120_fmha_test.cu b/src/kernels/attention/tests/sm120_fmha_test.cu
similarity index 98%
rename from src/kernels/attention/sm120_fmha_test.cu
rename to src/kernels/attention/tests/sm120_fmha_test.cu
index d3aeb905..f76a07bc 100644
--- a/src/kernels/attention/sm120_fmha_test.cu
+++ b/src/kernels/attention/tests/sm120_fmha_test.cu
@@ -4,9 +4,9 @@
 #include
 #include
+#include "device/sm120_fmha_launch.cuh"
 #include "mha_params.h"
-#include "mha_ref.h"
-#include "sm120_fmha_launch.cuh"
+#include "tests/mha_ref.h"
 namespace llm {
 #define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
diff --git a/src/kernels/attention/sm120_tma_block_load_test.cu b/src/kernels/attention/tests/sm120_tma_block_load_test.cu
similarity index 98%
rename from src/kernels/attention/sm120_tma_block_load_test.cu
rename to src/kernels/attention/tests/sm120_tma_block_load_test.cu
index 78a6117f..491e304d 100644
--- a/src/kernels/attention/sm120_tma_block_load_test.cu
+++ b/src/kernels/attention/tests/sm120_tma_block_load_test.cu
@@ -3,11 +3,11 @@
 #include
 #include
+#include
 #include
-#include "cute/swizzle_layout.hpp"
-#include "gather_tma_copy.h"
-#include "gather_tma_tensor.hpp"
+#include "common/gather_tma_copy.h"
+#include "common/gather_tma_tensor.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/sm80_mha_pagedkv_test.cu b/src/kernels/attention/tests/sm80_mha_pagedkv_test.cu
similarity index 98%
rename from src/kernels/attention/sm80_mha_pagedkv_test.cu
rename to src/kernels/attention/tests/sm80_mha_pagedkv_test.cu
index 454907ff..576b9238 100644
--- a/src/kernels/attention/sm80_mha_pagedkv_test.cu
+++ b/src/kernels/attention/tests/sm80_mha_pagedkv_test.cu
@@ -4,10 +4,10 @@
 #include
+#include "common/static_dispatch.h"
+#include "device/sm80_mha_dispatch.cuh"
 #include "mha_params.h"
-#include "mha_ref.h"
-#include "sm80_mha_dispatch.cuh"
-#include "static_dispatch.h"
+#include "tests/mha_ref.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/sm80_mha_test.cu b/src/kernels/attention/tests/sm80_mha_test.cu
similarity index 96%
rename from src/kernels/attention/sm80_mha_test.cu
rename to src/kernels/attention/tests/sm80_mha_test.cu
index 7880db1b..8f554117 100644
--- a/src/kernels/attention/sm80_mha_test.cu
+++ b/src/kernels/attention/tests/sm80_mha_test.cu
@@ -4,11 +4,11 @@
 #include
 #include
+#include "common/static_dispatch.h"
+#include "device/sm80_mha_dispatch.cuh"
+#include "device/sm80_mha_launch.cuh" // IWYU pragma: keep
 #include "mha_params.h"
-#include "mha_ref.h"
-#include "sm80_mha_dispatch.cuh"
-#include "sm80_mha_launch.cuh" // IWYU pragma: keep
-#include "static_dispatch.h"
+#include "tests/mha_ref.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/sm80_mla_pagedkv_test.cu b/src/kernels/attention/tests/sm80_mla_pagedkv_test.cu
similarity index 98%
rename from src/kernels/attention/sm80_mla_pagedkv_test.cu
rename to src/kernels/attention/tests/sm80_mla_pagedkv_test.cu
index 8dc1313e..93b0ebca 100644
--- a/src/kernels/attention/sm80_mla_pagedkv_test.cu
+++ b/src/kernels/attention/tests/sm80_mla_pagedkv_test.cu
@@ -2,10 +2,11 @@
 #include
 #include
-#include "cute/layout.hpp"
+#include
+
+#include "device/sm80_mla_dispatch.cuh"
 #include "mla_params.h"
-#include "mla_ref.h"
-#include "sm80_mla_dispatch.cuh"
+#include "tests/mla_ref.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/sm80_mla_test.cu b/src/kernels/attention/tests/sm80_mla_test.cu
similarity index 98%
rename from src/kernels/attention/sm80_mla_test.cu
rename to src/kernels/attention/tests/sm80_mla_test.cu
index 89564b8f..389cb3ff 100644
--- a/src/kernels/attention/sm80_mla_test.cu
+++ b/src/kernels/attention/tests/sm80_mla_test.cu
@@ -4,9 +4,9 @@
 #include
 #include
+#include "device/sm80_mla_dispatch.cuh"
 #include "mla_params.h"
-#include "mla_ref.h"
-#include "sm80_mla_dispatch.cuh"
+#include "tests/mla_ref.h"
 namespace llm {
 #define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) \
diff --git a/src/kernels/gemm/CMakeLists.txt b/src/kernels/gemm/CMakeLists.txt
index 01fc7e5f..7a51f6be 100644
--- a/src/kernels/gemm/CMakeLists.txt
+++ b/src/kernels/gemm/CMakeLists.txt
@@ -4,11 +4,13 @@ include(cc_test)
 cc_library(
   NAME
     gemm.kernels
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   HDRS
-    sm80_collective_grouped_gemm.cuh
-    sm80_collective_epilogue.cuh
-    sm80_grouped_gemm_launch.cuh
-    tile_scheduler.cuh
+    collective/sm80_collective_grouped_gemm.cuh
+    collective/sm80_collective_epilogue.cuh
+    device/sm80_grouped_gemm_launch.cuh
+    common/tile_scheduler.cuh
   DEPS
     cutlass
 )
@@ -16,8 +18,10 @@ cc_library(
 cc_test(
   NAME
     tile_scheduler_test
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
-    tile_scheduler_test.cu
+    tests/tile_scheduler_test.cu
   DEPS
     :gtest_main
     absl::random_random
@@ -27,8 +31,10 @@ cc_test(
 cc_test(
   NAME
     gemm_kernel_test
+  INCLUDES
+    ${CMAKE_CURRENT_SOURCE_DIR}
   SRCS
-    sm80_grouped_gemm_test.cu
+    tests/sm80_grouped_gemm_test.cu
   DEPS
     :gemm.kernels
     :gtest_main
diff --git a/src/kernels/gemm/sm80_collective_epilogue.cuh b/src/kernels/gemm/collective/sm80_collective_epilogue.cuh
similarity index 98%
rename from src/kernels/gemm/sm80_collective_epilogue.cuh
rename to src/kernels/gemm/collective/sm80_collective_epilogue.cuh
index ac84ae00..2aeaf124 100644
--- a/src/kernels/gemm/sm80_collective_epilogue.cuh
+++ b/src/kernels/gemm/collective/sm80_collective_epilogue.cuh
@@ -7,8 +7,8 @@
 #include
 #include
-#include "fast_cast.cuh"
-#include "safe_copy.hpp"
+#include "common/fast_cast.cuh"
+#include "common/safe_copy.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/gemm/sm80_collective_grouped_gemm.cuh b/src/kernels/gemm/collective/sm80_collective_grouped_gemm.cuh
similarity index 99%
rename from src/kernels/gemm/sm80_collective_grouped_gemm.cuh
rename to src/kernels/gemm/collective/sm80_collective_grouped_gemm.cuh
index 240fc6dd..4495c53f 100644
--- a/src/kernels/gemm/sm80_collective_grouped_gemm.cuh
+++ b/src/kernels/gemm/collective/sm80_collective_grouped_gemm.cuh
@@ -6,7 +6,7 @@
 #include
 #include
-#include "safe_copy.hpp"
+#include "common/safe_copy.h"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/attention/fast_cast.cuh b/src/kernels/gemm/common/fast_cast.cuh
similarity index 98%
rename from src/kernels/attention/fast_cast.cuh
rename to src/kernels/gemm/common/fast_cast.cuh
index 2a158b01..b574b1e2 100644
--- a/src/kernels/attention/fast_cast.cuh
+++ b/src/kernels/gemm/common/fast_cast.cuh
@@ -62,4 +62,4 @@ CUTE_DEVICE void fast_cast(const FragmentS& src, FragmentD& dst) {
   detail::type_cast::cast(src, dst);
 }
-} // namespace llm
\ No newline at end of file
+} // namespace llm
diff --git a/src/kernels/gemm/fast_math.h b/src/kernels/gemm/common/fast_math.h
similarity index 100%
rename from src/kernels/gemm/fast_math.h
rename to src/kernels/gemm/common/fast_math.h
diff --git a/src/kernels/gemm/gather_tensor.hpp b/src/kernels/gemm/common/gather_tensor.h
similarity index 100%
rename from src/kernels/gemm/gather_tensor.hpp
rename to src/kernels/gemm/common/gather_tensor.h
diff --git a/src/kernels/gemm/safe_copy.hpp b/src/kernels/gemm/common/safe_copy.h
similarity index 100%
rename from src/kernels/gemm/safe_copy.hpp
rename to src/kernels/gemm/common/safe_copy.h
diff --git a/src/kernels/gemm/static_dispatch.h b/src/kernels/gemm/common/static_dispatch.h
similarity index 100%
rename from src/kernels/gemm/static_dispatch.h
rename to src/kernels/gemm/common/static_dispatch.h
diff --git a/src/kernels/gemm/tile_scheduler.cuh b/src/kernels/gemm/common/tile_scheduler.cuh
similarity index 100%
rename from src/kernels/gemm/tile_scheduler.cuh
rename to src/kernels/gemm/common/tile_scheduler.cuh
diff --git a/src/kernels/gemm/sm80_grouped_gemm_dispatch.cuh b/src/kernels/gemm/device/sm80_grouped_gemm_dispatch.cuh
similarity index 96%
rename from src/kernels/gemm/sm80_grouped_gemm_dispatch.cuh
rename to src/kernels/gemm/device/sm80_grouped_gemm_dispatch.cuh
index 815e6190..3edd22a7 100644
--- a/src/kernels/gemm/sm80_grouped_gemm_dispatch.cuh
+++ b/src/kernels/gemm/device/sm80_grouped_gemm_dispatch.cuh
@@ -3,9 +3,8 @@
 #include
 #include
-#include "huggingface/safetensors.h"
-#include "sm80_grouped_gemm_launch.cuh"
-#include "static_dispatch.h"
+#include "common/static_dispatch.h"
+#include "device/sm80_grouped_gemm_launch.cuh"
 namespace llm {
 using namespace cute;
diff --git a/src/kernels/gemm/sm80_grouped_gemm_launch.cuh b/src/kernels/gemm/device/sm80_grouped_gemm_launch.cuh
similarity index 91%
rename from src/kernels/gemm/sm80_grouped_gemm_launch.cuh
rename to src/kernels/gemm/device/sm80_grouped_gemm_launch.cuh
index de1929f4..d9eaaf57 100644
--- a/src/kernels/gemm/sm80_grouped_gemm_launch.cuh
+++ b/src/kernels/gemm/device/sm80_grouped_gemm_launch.cuh
@@ -6,10 +6,10 @@
 #include
 #include
-#include "sm80_collective_epilogue.cuh"
-#include "sm80_collective_grouped_gemm.cuh"
-#include "sm80_kernel_grouped_gemm.cuh"
-#include "tile_scheduler.cuh"
+#include "collective/sm80_collective_epilogue.cuh"
+#include "collective/sm80_collective_grouped_gemm.cuh"
+#include "common/tile_scheduler.cuh"
+#include "kernel/sm80_kernel_grouped_gemm.cuh"
 namespace llm {
diff --git a/src/kernels/gemm/sm80_kernel_grouped_gemm.cuh b/src/kernels/gemm/kernel/sm80_kernel_grouped_gemm.cuh
similarity index 99%
rename from src/kernels/gemm/sm80_kernel_grouped_gemm.cuh
rename to src/kernels/gemm/kernel/sm80_kernel_grouped_gemm.cuh
index b21e060d..87f0e207 100644
--- a/src/kernels/gemm/sm80_kernel_grouped_gemm.cuh
+++ b/src/kernels/gemm/kernel/sm80_kernel_grouped_gemm.cuh
@@ -6,7 +6,7 @@
 #include
 #include
-#include "gather_tensor.hpp"
+#include "common/gather_tensor.h"
 namespace llm {
diff --git a/src/kernels/gemm/sm80_grouped_gemm_test.cu b/src/kernels/gemm/tests/sm80_grouped_gemm_test.cu
similarity index 98%
rename from src/kernels/gemm/sm80_grouped_gemm_test.cu
rename to src/kernels/gemm/tests/sm80_grouped_gemm_test.cu
index dc744ee3..808e33b6 100644
--- a/src/kernels/gemm/sm80_grouped_gemm_test.cu
+++ b/src/kernels/gemm/tests/sm80_grouped_gemm_test.cu
@@ -3,8 +3,8 @@
 #include
-#include "sm80_grouped_gemm_dispatch.cuh" // IWYU pragma: keep
-#include "static_dispatch.h"
+#include "common/static_dispatch.h"
+#include "device/sm80_grouped_gemm_dispatch.cuh"
 namespace llm {
diff --git a/src/kernels/gemm/tile_scheduler_test.cu b/src/kernels/gemm/tests/tile_scheduler_test.cu
similarity index 99%
rename from src/kernels/gemm/tile_scheduler_test.cu
rename to src/kernels/gemm/tests/tile_scheduler_test.cu
index 0dba4cbc..81e6ebe3 100644
--- a/src/kernels/gemm/tile_scheduler_test.cu
+++ b/src/kernels/gemm/tests/tile_scheduler_test.cu
@@ -2,7 +2,7 @@
 #include
-#include "tile_scheduler.cuh"
+#include "common/tile_scheduler.cuh"
 namespace llm {
diff --git a/third_party/cutlass b/third_party/cutlass
index 4ec8dd93..b2fd3b08 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit 4ec8dd93c469e4903b6ff65de02cec089c97944c
+Subproject commit b2fd3b08880e8b1395c80c4cb46cf1bafb74b7c9