
Commit 740ac12

Feat: wait kernel (#41)
1 parent 09c1ffe commit 740ac12

File tree: 10 files changed (+338, -21 lines)

CMakeLists.txt

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,13 @@ cmake_minimum_required(VERSION 3.22 FATAL_ERROR)
 project(af LANGUAGES C CXX)
 
 set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 ")
+execute_process(COMMAND ${Python_EXECUTABLE}
+                -c "import torch; print(int(torch.compiled_with_cxx11_abi()))"
+                OUTPUT_VARIABLE TORCH_CXX11_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ")
+
+
+# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 ")
 
 # import pytorch library
 find_package (Python COMPONENTS Interpreter Development)
@@ -19,7 +25,9 @@ find_package(Torch REQUIRED CONFIG)
 message("MY TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS}")
 message("MY CUDA_INCLUDE_DIRS ${CUDA_INCLUDE_DIRS}")
 include_directories(${TORCH_INCLUDE_DIRS})
+# Save ABI setting before adding TORCH_CXX_FLAGS (which might override it)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+# Ensure ABI setting is preserved after TORCH_CXX_FLAGS
 
 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
 list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
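
The probe above bakes the C++ ABI of the installed PyTorch into the extension's compile flags; building with the opposite `_GLIBCXX_USE_CXX11_ABI` value from libtorch typically fails at import time with undefined `std::__cxx11::basic_string` symbols. The same check can be run by hand; this is the exact one-liner that `execute_process` invokes:

    # Prints 1 if libtorch was built with the C++11 ABI, 0 for the pre-C++11 ABI.
    import torch
    print(int(torch.compiled_with_cxx11_abi()))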

fserver/csrc/kernel.hpp

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+
+
+torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index);
+void write_flag(torch::Tensor flag, torch::Tensor seq);
+void wait_flag(torch::Tensor flag, torch::Tensor seq);
+void seq_add_one(torch::Tensor seq);
+
+void pybind_kernel(py::module &m){
+  // StepMesh utils
+  m.def("map_pinned_tensor", &map_pinned_tensor, py::arg("tensor"), py::arg("device_index"));
+  m.def("write_flag", &write_flag, py::arg("flag"), py::arg("seq"));
+  m.def("wait_flag", &wait_flag, py::arg("flag"), py::arg("seq"));
+  m.def("seq_add_one", &seq_add_one, py::arg("seq"));
+}
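
Together these four entry points form a CPU↔GPU doorbell: `map_pinned_tensor` aliases a pinned host buffer into device address space, `write_flag` and `seq_add_one` publish progress from a CUDA stream, and `wait_flag` stalls a stream until the flag catches up to a sequence number. A minimal sketch of the GPU→CPU direction, assuming the `fserver_lib` module name from setup.py (the calling pattern itself is an assumption, not documented in this commit):

    import torch
    import fserver_lib  # extension name taken from setup.py

    # Pinned host flag plus a zero-copy device-side view of the same memory.
    flag_host = torch.zeros(1, dtype=torch.int64, pin_memory=True)
    flag_dev = fserver_lib.map_pinned_tensor(flag_host, 0)
    seq = torch.ones(1, dtype=torch.int64, device='cuda')

    fserver_lib.write_flag(flag_dev, seq)  # stream stores seq into the flag
    while flag_host[0].item() < 1:         # CPU observes the store through pinned memory
        pass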

fserver/csrc/ops.cc

Lines changed: 8 additions & 2 deletions
@@ -1,12 +1,18 @@
 /* Copyright (c) 2025, StepFun Authors. All rights reserved. */
 
 
-#include "./util.h"
+#include "./util.hpp"
 
 #include "./public.hpp"
-#include "./private.hpp"
+#ifdef DMLC_USE_CUDA
+#include "./private.hpp"
+#include "./kernel.hpp"
+#endif
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   pybind_public(m);
+#ifdef DMLC_USE_CUDA
   pybind_private(m);
+  pybind_kernel(m);
+#endif
 }

fserver/csrc/private.hpp

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 /* Copyright (c) 2025, StepFun Authors. All rights reserved. */
 
-#include "./util.h"
+#include "./util.hpp"
 #include "./public.hpp"
 #include <future>
 #ifdef DMLC_USE_CUDA

fserver/csrc/public.hpp

Lines changed: 6 additions & 4 deletions
@@ -15,7 +15,7 @@
 
 #include <chrono>
 
-#include "./util.h"
+#include "./util.hpp"
 
 #ifndef PUBLIC_OPS_
 #define PUBLIC_OPS_
@@ -121,7 +121,8 @@ void respond_vec(torch::Tensor& ret_buffer,
 int push_pull(std::vector<torch::Tensor>& push_tensors,
               std::vector<uint64_t>& push_keys,
               std::vector<torch::Tensor>& pull_tensors,
-              std::vector<uint64_t>& pull_keys) {
+              std::vector<uint64_t>& pull_keys,
+              bool need_event = true) {
 
   PS_CHECK_EQ(push_tensors.size(), push_keys.size());
   PS_CHECK_EQ(pull_tensors.size(), pull_keys.size());
@@ -138,7 +139,7 @@ int push_pull(std::vector<torch::Tensor>& push_tensors,
       static_cast<uint64_t>(pull_keys[i]), std::move(pull_tensors[i].detach())
     };
   }
-  return fworker_->ZBatchPushPull(push_batch, pull_batch);
+  return fworker_->ZBatchPushPull(push_batch, pull_batch, need_event);
 }
 
 void wait(int handler, uint64_t timeout_ms = 1000) {
@@ -237,7 +238,7 @@ uint64_t get_nanosecond() {
 
 
 void pybind_public(py::module &m){
-  m.def("init", &init, py::call_guard<py::gil_scoped_release>());
+  m.def("init", &init, py::call_guard<py::gil_scoped_release>());
   m.def("stop", &stop, py::call_guard<py::gil_scoped_release>());
 
   m.def("register_recv_buffer",
@@ -250,6 +251,7 @@ void pybind_public(py::module &m){
         py::arg("push_keys"),
         py::arg("pull_tensors"),
        py::arg("pull_keys"),
+        py::arg("need_event") = true,
         py::call_guard<py::none>());
   m.def("wait", &wait,
        py::arg("handler"),
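
The new `need_event` argument lets a caller opt out of per-request CUDA event recording when completion is tracked some other way, for example by the wait-flag kernels added in this commit; synchronization then becomes the caller's responsibility. A hedged sketch of the call shape, assuming an already-initialized worker (tensors and keys are placeholders, not values from this repo):

    import torch
    import fserver_lib

    act = torch.randn(16, device='cuda')
    out = torch.empty_like(act)
    handler = fserver_lib.push_pull(
        push_tensors=[act], push_keys=[0],
        pull_tensors=[out], pull_keys=[1],
        need_event=False)  # default is True, preserving the old behavior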

fserver/csrc/wait_kernel.cu

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+
+__global__ void write_flag_kernel(int64_t* flag, int64_t* seq) {
+  int64_t seq_value = seq[0];
+  if (threadIdx.x == 0) {
+    flag[0] = seq_value;
+    // After the write, issue a system-scope fence so the store is visible to all threads and to the CPU.
+  }
+  __threadfence_system();
+}
+
+__global__ void wait_flag_kernel(int64_t* flag, int64_t* seq) {
+  if (threadIdx.x == 0) {
+    // Mark pointer volatile so we reload host-written values each iteration.
+    volatile int64_t* flag_ptr = flag, *seq_ptr = seq;
+    int64_t flag_value = flag_ptr[0];
+    int64_t seq_value = seq_ptr[0];
+    while (flag_value < seq_value) {
+      __nanosleep(128);
+      flag_value = flag_ptr[0];
+    }
+  }
+}
+
+__global__ void seq_add_one_kernel(int64_t* seq) {
+  if (threadIdx.x == 0) {
+    seq[0]++;
+  }
+  __threadfence_system();
+}
+
+static void check_cuda(cudaError_t err, const char* msg) {
+  TORCH_CHECK(err == cudaSuccess, msg, ": ", cudaGetErrorString(err));
+}
+
+torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index) {
+  TORCH_CHECK(tensor.is_pinned(), "tensor must be pinned");
+  void* host_ptr = tensor.data_ptr();
+  void* device_ptr = nullptr;
+  check_cuda(cudaHostGetDevicePointer(&device_ptr, host_ptr, 0),
+             "cudaHostGetDevicePointer failed");
+  auto options = tensor.options().device(torch::kCUDA, device_index);
+  auto sizes = tensor.sizes();
+  auto strides = tensor.strides();
+  return torch::from_blob(device_ptr, sizes, strides, [](void*){}, options);
+}
+
+void write_flag(torch::Tensor flag, torch::Tensor seq) {
+  TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
+  auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
+  write_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq.data_ptr<int64_t>());
+  check_cuda(cudaGetLastError(), "write_flag_kernel launch failed");
+}
+
+void wait_flag(torch::Tensor flag, torch::Tensor seq) {
+  TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
+  auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
+  wait_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq.data_ptr<int64_t>());
+  check_cuda(cudaGetLastError(), "wait_flag_kernel launch failed");
+}
+
+void seq_add_one(torch::Tensor seq) {
+  TORCH_CHECK(seq.is_cuda(), "seq must be a CUDA tensor");
+  auto stream = at::cuda::getCurrentCUDAStream(seq.device().index());
+  seq_add_one_kernel<<<1, 1, 0, stream>>>(seq.data_ptr<int64_t>());
+  check_cuda(cudaGetLastError(), "seq_add_one_kernel launch failed");
+}
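
`wait_flag_kernel` parks a single thread that re-reads the flag through a volatile pointer with a `__nanosleep(128)` backoff, so a plain CPU store into the mapped pinned buffer can release a whole stream without `cudaStreamSynchronize` or event machinery; `__threadfence_system()` in the writer kernels publishes stores at system scope so the CPU (and peer devices) observe them. A hedged round-trip sketch of the CPU→GPU direction (again assuming the `fserver_lib` bindings; this is not a test shipped with the commit):

    import torch
    import fserver_lib

    flag_host = torch.zeros(1, dtype=torch.int64, pin_memory=True)
    flag_dev = fserver_lib.map_pinned_tensor(flag_host, 0)
    seq = torch.ones(1, dtype=torch.int64, device='cuda')

    x = torch.randn(1024, 1024, device='cuda')
    fserver_lib.wait_flag(flag_dev, seq)  # stream now stalls until flag >= 1
    y = x @ x                             # queued behind the spin-wait
    flag_host[0] = 1                      # CPU store releases the stream
    torch.cuda.synchronize()              # completes once the kernel sees the flag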

include/ps/af_tensor_app.h

Lines changed: 7 additions & 4 deletions
@@ -93,8 +93,8 @@ class AFTensorWorker {
    * where the pulled tensors and their associated keys will be stored.
    * @return An integer indicating the result of the operation.
    */
-  int ZBatchPushPull(KeyTensorBatch& push_tensors,
-                     KeyTensorBatch& pull_tensors) {
+  int ZBatchPushPull(KeyTensorBatch& push_tensors, KeyTensorBatch& pull_tensors,
+                     bool need_event = true) {
     Backend::Get()->SetDevice(gpu_);
     auto server_ranges =
         Postoffice::GetWorker(instance_id_)->GetServerKeyRanges();
@@ -130,8 +130,11 @@ class AFTensorWorker {
 
     req.push = push_tensors;
     req.pull = pull_tensors;
-    req.event = GetEvent();
-    req.event->Record();
+    req.event = nullptr;
+    if (need_event) {
+      req.event = GetEvent();
+      req.event->Record();
+    }
 
     PS_VLOG(3) << "ts" << start_ts << " pushpull_queue_ push "
                << pushpull_queue_.Size();

setup.py

Lines changed: 5 additions & 9 deletions
@@ -10,7 +10,7 @@
 from pathlib import Path
 
 def get_version():
-    version = '0.0.4.post1'
+    version = '0.0.5.post1'
     # with open('stepkv/version.py', 'r') as fd:
     #     version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
     #                          fd.read(), re.MULTILINE).group(1)
@@ -62,16 +62,11 @@ def _get_cuda_bare_metal_version(cuda_dir):
 if use_cuda:
     extra_link += ['-lcuda', '-lcudart']
     extra_compile_args['cxx'] += ['-DDMLC_USE_CUDA',]
-    extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_70,code=sm_70',
-                                  '--use_fast_math'] + cc_flag
+    extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_90,code=sm_90', '-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_89,code=sm_89', '-gencode', 'arch=compute_90a,code=sm_90a',
+                                  '--use_fast_math', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(torch_cxx11_abi))}'] + cc_flag
     bare_metal_major, bare_metal_minor = \
         _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-    if int(bare_metal_minor) >= 8 or int(bare_metal_major) >= 12:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_90,code=sm_90')
+
 setup(
     name='FServer',
     description='A Remote FFN Server Implementation for AF Disaggregation',
@@ -84,6 +79,7 @@ def _get_cuda_bare_metal_version(cuda_dir):
         'fserver_lib',
         [
             __SRC_PATH__ + 'ops.cc',
+            __SRC_PATH__ + 'wait_kernel.cu',
         ],
         extra_compile_args=extra_compile_args,
         extra_link_args=extra_link,
