Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 57 additions & 39 deletions src/kernels/attention/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ include(cc_test)
cc_library(
NAME
attention.template
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
HDRS
fast_math.h
layout_convertor.h
Expand Down Expand Up @@ -41,13 +43,13 @@ file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/gensrc/*.cu")
cc_library(
NAME
attention.kernels
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
HDRS
attn_api.h
SRCS
attn_api.cpp
${GENERATED_SRC_FILES}
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
DEPS
:attention.template
glog::glog
Expand All @@ -56,9 +58,11 @@ cc_library(
cc_test(
NAME
mha_kernel_test
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
SRCS
# sm80_mha_test.cu
sm80_mha_pagedkv_test.cu
tests/sm80_mha_pagedkv_test.cu
DEPS
:attention.kernels
absl::random_random
Expand All @@ -69,77 +73,91 @@ cc_test(
cc_test(
NAME
sm120_fmha_kernel_test
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
SRCS
sm120_fmha_test.cu
tests/sm120_fmha_test.cu
DEPS
:attention.template
absl::random_random
:gtest_main
absl::random_random
torch
)

cc_test(
NAME
mla_kernel_test
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
SRCS
sm80_mla_test.cu
sm80_mla_pagedkv_test.cu
tests/sm80_mla_test.cu
tests/sm80_mla_pagedkv_test.cu
DEPS
:attention.kernels
absl::random_random
:gtest_main
absl::random_random
torch
)

cc_test(
NAME
attn_combine_kernel_test
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
SRCS
attn_combine_kernel_test.cu
tests/attn_combine_kernel_test.cu
DEPS
:attention.template
absl::random_random
:gtest_main
torch
)

nvbench_binary(
NAME
sm80_mha_bench
SRCS
sm80_mha_bench.cu
DEPS
:attention.template
)

nvbench_binary(
NAME
sm80_mha_pagedkv_bench
SRCS
sm80_mha_pagedkv_bench.cu
DEPS
absl::random_random
:attention.template
)

nvbench_binary(
NAME
sm80_mla_bench
SRCS
mla_sm80_bench.cu
DEPS
:attention.template
)

cc_test(
NAME
tma_test
INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}
SRCS
sm120_tma_block_load_test.cu
tests/sm120_tma_block_load_test.cu
DEPS
:gtest_main
:cutlass
torch
)

add_subdirectory(tools)
# nvbench_binary(
# NAME
# sm80_mha_bench
# INCLUDES
# ${CMAKE_CURRENT_SOURCE_DIR}
# SRCS
# sm80_mha_bench.cu
# DEPS
# :attention.template
# )

# nvbench_binary(
# NAME
# sm80_mha_pagedkv_bench
# INCLUDES
# ${CMAKE_CURRENT_SOURCE_DIR}
# SRCS
# sm80_mha_pagedkv_bench.cu
# DEPS
# absl::random_random
# :attention.template
# )

# nvbench_binary(
# NAME
# sm80_mla_bench
# INCLUDES
# ${CMAKE_CURRENT_SOURCE_DIR}
# SRCS
# mla_sm80_bench.cu
# DEPS
# :attention.template
# )

# add_subdirectory(tools)
7 changes: 4 additions & 3 deletions src/kernels/attention/attn_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

#include <ATen/cuda/CUDAContext.h>

#include "cute/layout.hpp"
#include <cute/layout.hpp>

#include "common/static_dispatch.h"
#include "device/sm80_mha_dispatch.cuh"
#include "mha_params.h"
#include "sm80_mha_dispatch.cuh"
#include "static_dispatch.h"

namespace llm {
using namespace cute;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "fast_cast.cuh"
#include "safe_copy.h"
#include "common/fast_cast.cuh"
#include "common/safe_copy.h"

namespace llm {
using namespace cute;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include <cute/tensor.hpp>
#include <cutlass/pipeline/pipeline.hpp>

#include "fast_cast.cuh"
#include "layout_convertor.h"
#include "safe_copy.h"
#include "common/fast_cast.cuh"
#include "common/layout_convertor.h"
#include "common/safe_copy.h"
#include "sm120_collective_load_cpasync_ws.cuh"

namespace llm {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "safe_copy.h"
#include "common/safe_copy.h"

namespace llm {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "fast_cast.cuh"
#include "safe_copy.h"
#include "common/fast_cast.cuh"
#include "common/safe_copy.h"

namespace llm {
using namespace cute;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "fast_cast.cuh"
#include "layout_convertor.h"
#include "safe_copy.h"
#include "common/fast_cast.cuh"
#include "common/layout_convertor.h"
#include "common/safe_copy.h"

namespace llm {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "fast_cast.cuh"
#include "layout_convertor.h"
#include "mask.h"
#include "safe_copy.h"
#include "common/fast_cast.cuh"
#include "common/layout_convertor.h"
#include "common/mask.h"
#include "common/safe_copy.h"

namespace llm {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "fast_cast.cuh"
#include "safe_copy.h"
#include "common/fast_cast.cuh"
#include "common/safe_copy.h"

namespace llm {
using namespace cute;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/common/gather_tensor.hpp
#pragma once

#include "cute/layout.hpp"
#include "cute/layout_composed.hpp"
#include "cute/tensor.hpp"
#include <cute/layout.hpp>
#include <cute/layout_composed.hpp>
#include <cute/tensor.hpp>
namespace llm {

using namespace cute;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ struct GatherArithmeticTupleIterator {
const Transform& transform)
: coord_(coord), transform_(transform) {}

CUTE_HOST_DEVICE constexpr auto coord() const {
CUTE_HOST_DEVICE constexpr auto operator*() const {
// apply the transform to the coordinate
return transform_(coord_);
}

CUTE_HOST_DEVICE constexpr auto operator*() const { return coord(); }

template <class Coord>
CUTE_HOST_DEVICE constexpr auto operator+(const Coord& c) const {
auto coord = coord_ + c;
Expand Down Expand Up @@ -67,15 +65,15 @@ template <class ArithTuple, class Transform>
CUTE_HOST_DEVICE void print(
const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
printf("GatherArithTuple");
print(iter.coord());
print(*iter);
}

#if !defined(__CUDACC_RTC__)
template <class ArithTuple, class Transform>
CUTE_HOST std::ostream& operator<<(
std::ostream& os,
const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
return os << "GatherArithTuple" << iter.coord();
return os << "GatherArithTuple" << *iter;
}
#endif

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,4 @@ struct OnlineSoftmax {
}
};

} // namespace llm
} // namespace llm
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#pragma once

#include <cute/atom/mma_atom.hpp>
#include <cute/config.hpp>
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "cute/config.hpp"
#include "cute/layout.hpp"

namespace cute {

namespace detail {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ namespace llm {
} \
}()

} // namespace llm
} // namespace llm
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>

#include "static_dispatch.h"
#include "common/static_dispatch.h"

namespace llm {
// forward declaration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "sm120_collective_epilogue.cuh"
#include "sm120_collective_fmha_mainloop_ws.cuh"
#include "sm120_kernel_fmha_ws.cuh"
#include "tile_scheduler.cuh"
#include "collective/sm120_collective_epilogue.cuh"
#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
#include "common/tile_scheduler.cuh"
#include "kernel/sm120_kernel_fmha_ws.cuh"

namespace llm {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>

#include "static_dispatch.h"
#include "common/static_dispatch.h"

namespace llm {
// forward declaration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "sm80_collective_epilogue.cuh"
#include "sm80_collective_mha.cuh"
#include "sm80_kernel_mha.cuh"
#include "tile_scheduler.cuh"
#include "collective/sm80_collective_epilogue.cuh"
#include "collective/sm80_collective_mha.cuh"
#include "common/tile_scheduler.cuh"
#include "kernel/sm80_kernel_mha.cuh"

namespace llm {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>

#include "common/static_dispatch.h"
#include "sm80_mla_launch.cuh"
#include "static_dispatch.h"

namespace llm {

Expand Down
Loading