Skip to content

Commit 2c32009

Browse files
authored
refactor: stop building the flash_infer kernel (#386)
1 parent 90f5415 commit 2c32009

File tree

5 files changed

+3
-13
lines changed

5 files changed

+3
-13
lines changed

scalellm/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ pybind_extension(
1818
DEPS
1919
:llm_handler
2020
:marlin.kernels
21-
:flash_infer.kernels
2221
torch
2322
torch_python
2423
absl::strings

scalellm/csrc/kernels.cu

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <pybind11/pybind11.h>
22
#include <torch/extension.h>
33

4-
#include "kernels/attention/flash_infer/attention_wrapper.h"
54
#include "kernels/quantization/marlin.h"
65

76
namespace llm::csrc {
@@ -53,16 +52,6 @@ void init_kernels(py::module_& m) {
5352
py::arg("q_weight"),
5453
py::arg("out"),
5554
py::arg("num_bits"));
56-
57-
// flashinfer kernels
58-
py::class_<flashinfer::BatchPrefillWrapper>(m, "BatchPrefillWrapper")
59-
.def(py::init<bool>())
60-
.def("plan", &flashinfer::BatchPrefillWrapper::Plan)
61-
.def("is_cuda_graph_enabled",
62-
&flashinfer::BatchPrefillWrapper::IsCUDAGraphEnabled)
63-
.def("update_page_locked_buffer_size",
64-
&flashinfer::BatchPrefillWrapper::UpdatePageLockedBufferSize)
65-
.def("run", &flashinfer::BatchPrefillWrapper::Run);
6655
}
6756

6857
} // namespace llm::csrc

src/kernels/attention/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,5 @@ cc_binary(
8080
)
8181

8282
add_subdirectory(flash_attn)
83-
add_subdirectory(flash_infer)
83+
# add_subdirectory(flash_infer)
8484
add_subdirectory(tools)

tests/kernels/attention/flash_infer_kv_fp8_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import scalellm._C.kernels as kernels # type: ignore
88

99

10+
@pytest.mark.skip(reason="Not implemented")
1011
@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
1112
@pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
1213
@pytest.mark.parametrize("head_size", [64, 128, 256])

tests/kernels/attention/flash_infer_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import scalellm._C.kernels as kernels # type: ignore
88

99

10+
@pytest.mark.skip(reason="Not implemented")
1011
@pytest.mark.parametrize("seq_lens", [[(1, 100)], [(100, 100)], [(1, 100), (15, 15), (111, 234), (1000, 10000)]])
1112
@pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
1213
@pytest.mark.parametrize("head_size", [64, 128, 256])

0 commit comments

Comments (0)