Skip to content

Commit b8be593

Browse files
committed
temp: add xPUBackend; refine source tree
TODO: complete the xPUBackend registry; add more interfaces for XPU chips; export some header files; refine.
1 parent 740ac12 commit b8be593

File tree

23 files changed

+539
-49
lines changed

23 files changed

+539
-49
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,4 @@ target_link_libraries(af
6767
${TORCH_PYTHON_LIBRARY})
6868

6969
add_subdirectory(tests)
70+
add_subdirectory(plugins/klx_backend)

fserver/csrc/public.hpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
/* Copyright (c) 2025, StepFun Authors. All rights reserved. */
2+
3+
#include <dlfcn.h>
4+
25
#include <execinfo.h>
36
#include <stdio.h>
47
#include <signal.h>
@@ -43,6 +46,7 @@ uint64_t handler_counter_ = 0;
4346
std::unordered_map<uint64_t, AFTensorMeta> meta_map_;
4447
std::vector<std::deque<ServerDataBatch>> q_;
4548
std::atomic<uint64_t> q_signal_;
49+
static void* gPluginHandle = nullptr;
4650

4751
void RequestHandler(const AFTensorMeta& req_meta, AFTensorServer* server) {
4852
std::vector<torch::Tensor> tensors;
@@ -165,7 +169,12 @@ void barrier(bool include_server, bool include_worker, bool instrance_barrier=tr
165169
}
166170

167171

168-
void init() {
172+
void init(const std::string& plugin) {
173+
if (!plugin.empty()) {
174+
gPluginHandle = dlopen(plugin.c_str(), RTLD_NOW);
175+
PS_CHECK(gPluginHandle)
176+
<< "can't load plugin:" << plugin << ": " << dlerror();
177+
}
169178

170179
std::string role_str = ps::GetEnv("DMLC_ROLE", "server");
171180
int offset = 0;
@@ -204,14 +213,17 @@ void stop() {
204213
ps::Postoffice::GetWorker(gpu_)->DoBarrier(0,
205214
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
206215
} else if (role_ == Node::SERVER) {
207-
ps::Postoffice::GetServer(gpu_)->DoBarrier(0,
208-
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
216+
ps::Postoffice::GetServer(gpu_)->DoBarrier(
217+
0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
209218
} else {
210-
ps::Postoffice::Get()->DoBarrier(0,
211-
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
219+
ps::Postoffice::Get()->DoBarrier(
220+
0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
212221
}
213222

214223
ps::Finalize(0, role_, true);
224+
if (gPluginHandle) {
225+
dlclose(gPluginHandle);
226+
}
215227
}
216228

217229
std::vector<int> get_all_handlers(int handler) {
@@ -238,7 +250,8 @@ uint64_t get_nanosecond() {
238250

239251

240252
void pybind_public(py::module &m){
241-
m.def("init", &init, py::call_guard<py::gil_scoped_release>());
253+
m.def("init", &init, py::arg("plugin") = "",
254+
py::call_guard<py::gil_scoped_release>());
242255
m.def("stop", &stop, py::call_guard<py::gil_scoped_release>());
243256

244257
m.def("register_recv_buffer",

include/dmlc/backend_registry.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include "base.h"
4+
#include "ps/backend.h"
5+
6+
namespace dmlc {
7+
8+
// Registration helper: constructing a backend_registry<T> registers a lazy
// factory for backend type T under `name` via ps::Backend::RegisterLazy.
// T must be default-constructible and its pointer convertible to the
// factory's return type (the lambda returns `new T()`).
// NOTE(review): the single-argument constructor is not `explicit`, so a
// std::string converts implicitly to a registry object - confirm that is
// intended. <string> is not included here directly; it presumably arrives
// through "base.h" / "ps/backend.h" - verify.
template <typename T>
9+
struct STEPMESH_API backend_registry {
10+
// Registers the factory; no T instance is created by this constructor
// itself - the lambda only runs when the registered factory is invoked.
backend_registry(const std::string& name) {
11+
// The factory hands back a raw pointer from `new`; who eventually deletes
// it is not visible from this header.
ps::Backend::RegisterLazy(name, []() { return new T(); });
12+
}
13+
};
14+
15+
} // namespace dmlc

include/dmlc/base.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,6 @@ inline const char *BeginPtr(const std::string &str) {
190190
#define alignof __alignof
191191
#endif
192192

193+
// Marks a class or function with default symbol visibility using the
// GCC/Clang `visibility` attribute. In this commit it is applied to
// ps::Backend and dmlc::backend_registry.
// NOTE(review): this spelling is GCC/Clang-specific; there is no fallback
// for other compilers in the lines shown here.
#define STEPMESH_API __attribute__((__visibility__("default")))
194+
193195
#endif // DMLC_BASE_H_

include/dmlc/logging.h

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,35 @@ inline void InitLogging(const char *argv0) {
7171
// DO NOTHING
7272
}
7373

74+
/*!
 * \brief Return the file-name component of a '/'-separated path.
 *
 * Scans forward and remembers the character following the most recent '/'.
 * Unlike a backwards scan from the terminator, this never reads before the
 * start of the buffer when the path contains no '/' (e.g. when __FILE__ is a
 * bare file name), and it needs no strlen call, so it is a valid constant
 * expression without relying on compiler builtins.
 *
 * \param path NUL-terminated path; must not be null.
 * \return pointer into \p path just past the last '/', or \p path itself when
 *         it contains no '/'. A path ending in '/' yields an empty string.
 */
constexpr const char *getFileName(const char *path) {
  const char *name = path;
  for (const char *cursor = path; *cursor != '\0'; ++cursor) {
    if (*cursor == '/') {
      name = cursor + 1;
    }
  }
  return name;
}
81+
7482
// Always-on checking
75-
#define PS_CHECK(x) \
76-
if (!(x)) \
77-
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \
78-
"failed: " #x \
79-
<< ' '
83+
// PS_CHECK(x): always-on check. When `x` evaluates to false, constructs a
// dmlc::LogMessageFatal tagged with dmlc::getFileName(__FILE__) and __LINE__
// and streams "Check failed: <stringified x> " into it; the call site may
// append further context with `<<`.
// NOTE(review): the expansion is an unbraced `if` with no `else`, so when the
// macro is used as the unbraced body of an outer if/else, the following
// `else` binds to this `if`. Wrap such uses in braces.
#define PS_CHECK(x) \
84+
if (!(x)) \
85+
dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__).stream() \
86+
<< "Check " \
87+
"failed: " #x \
88+
<< ' '
8089
#define PS_CHECK_LT(x, y) PS_CHECK((x) < (y))
8190
#define PS_CHECK_GT(x, y) PS_CHECK((x) > (y))
8291
#define PS_CHECK_LE(x, y) PS_CHECK((x) <= (y))
8392
#define PS_CHECK_GE(x, y) PS_CHECK((x) >= (y))
8493
#define PS_CHECK_EQ(x, y) PS_CHECK((x) == (y))
8594
#define PS_CHECK_NE(x, y) PS_CHECK((x) != (y))
86-
#define PS_CHECK_NOTNULL(x) \
87-
((x) == NULL \
88-
? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', \
95+
#define PS_CHECK_NOTNULL(x) \
96+
((x) == NULL \
97+
? dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__).stream() \
98+
<< "Check notnull: " #x << ' ', \
8999
(x) : (x)) // NOLINT(*)
90100
// Debug-only checking.
91101
#ifdef NDEBUG
92-
/*
102+
/*
93103
#define DPS_CHECK(x) \
94104
while (false) PS_CHECK(x)
95105
#define DPS_CHECK_LT(x, y) \
@@ -114,12 +124,12 @@ inline void InitLogging(const char *argv0) {
114124
#define DPS_CHECK_NE(x, y) PS_CHECK((x) != (y)) */
115125
#endif // NDEBUG
116126

117-
#define PS_LOG_API dmlc::LogMessage(__FILE__, __LINE__)
127+
#define PS_LOG_API dmlc::LogMessage(dmlc::getFileName(__FILE__), __LINE__)
118128

119129
#define PS_LOG_IF(severity, condition) \
120130
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & PS_LOG_API
121131

122-
#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__)
132+
#define LOG_FATAL dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__)
123133
#define PS_LOG_FATAL LOG_FATAL.stream()
124134
#define LOG_QFATAL LOG_FATAL
125135

include/ps/af_tensor_app.h

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
#include "ps/base.h"
2222
#include "ps/hash_table8.hpp"
23-
#include "ps/internal/backend.h"
23+
#include "ps/backend.h"
2424
#include "ps/internal/utils.h"
2525
#include "ps/kv_app.h"
2626

@@ -236,15 +236,16 @@ class AFTensorWorker {
236236
void ZPush_(int ts, const SArray<Key>& keys, const at::Tensor& tensor,
237237
int cmd = 0) {
238238
SArray<char> val;
239-
val.reset(reinterpret_cast<char*>(tensor.data_ptr()),
239+
void* mappedPtr = Backend::Get()->GetAccessibleAddr(tensor);
240+
val.reset(reinterpret_cast<char*>(mappedPtr),
240241
tensor.numel() * tensor.itemsize(), [tensor](void*) {});
241242

242243
Message msg;
243244
msg.meta.request = true;
244245
msg.meta.head = cmd;
245246
msg.meta.push = true;
246247
msg.meta.timestamp = ts;
247-
msg.meta.addr = reinterpret_cast<uint64_t>(tensor.data_ptr());
248+
msg.meta.addr = reinterpret_cast<uint64_t>(mappedPtr);
248249
msg.meta.val_len = tensor.numel() * tensor.itemsize();
249250
PS_VLOG(2) << "ZPush_ addr: 0x" << std::hex << msg.meta.addr << std::dec
250251
<< " val_len: " << msg.meta.val_len;
@@ -284,13 +285,14 @@ class AFTensorWorker {
284285

285286
*key.data() = pull_tensors[i * pull_batch_size + index].key;
286287

287-
val.reset(reinterpret_cast<char*>(tensor.data_ptr()),
288+
void* mappedPtr = Backend::Get()->GetAccessibleAddr(tensor);
289+
val.reset(reinterpret_cast<char*>(mappedPtr),
288290
tensor.numel() * tensor.itemsize(), [tensor](void*) {});
289291

290292
msg.meta.request = true;
291293
msg.meta.head = cmd;
292294
msg.meta.push = false;
293-
msg.meta.addr = reinterpret_cast<uint64_t>(tensor.data_ptr());
295+
msg.meta.addr = reinterpret_cast<uint64_t>(mappedPtr);
294296
msg.meta.val_len = tensor.numel() * tensor.itemsize();
295297
msg.meta.key = key[0];
296298
msg.meta.is_tensor = 1;
@@ -483,7 +485,8 @@ class AFTensorServer {
483485
res.keys = key;
484486

485487
SArray<char> tensor_val;
486-
tensor_val.reset(reinterpret_cast<char*>(tensors[0].val.data_ptr()),
488+
tensor_val.reset(reinterpret_cast<char*>(
489+
Backend::Get()->GetAccessibleAddr(tensors[0].val)),
487490
tensors[0].val.numel() * tensors[0].val.itemsize(),
488491
[](void*) {});
489492
res.vals = tensor_val;
@@ -506,7 +509,8 @@ class AFTensorServer {
506509
rsp.kv_pair.keys = key;
507510

508511
rsp.kv_pair.vals.reset(
509-
reinterpret_cast<char*>(res_kv.val.data_ptr()),
512+
reinterpret_cast<char*>(
513+
Backend::Get()->GetAccessibleAddr(res_kv.val)),
510514
res_kv.val.numel() * res_kv.val.itemsize(), [](void*) {});
511515

512516
rsp.kv_meta = kv_meta;
@@ -558,7 +562,8 @@ class AFTensorServer {
558562
PS_CHECK_GT(worker_ranks.size(), 0) << "ranks or keys should not be empty";
559563
PS_CHECK_EQ(worker_ranks.size(), keys.size())
560564
<< "rank list and key list have unequal size";
561-
char* buffer_ptr = reinterpret_cast<char*>(tensor.data_ptr());
565+
char* buffer_ptr =
566+
reinterpret_cast<char*>(Backend::Get()->GetAccessibleAddr(tensor));
562567
uint64_t data_size = tensor.numel() * tensor.element_size();
563568
int chunk_size = data_size / worker_ranks.size();
564569
PS_CHECK_EQ(data_size % worker_ranks.size(), 0)
@@ -591,8 +596,14 @@ class AFTensorServer {
591596
.dtype(at::ScalarType(req_meta.dtype))
592597
.memory_format(at::MemoryFormat::Contiguous)
593598
.device(Backend::Get()->GetDevice());
594-
key_tensor.val =
595-
at::from_blob(req_data.vals.data(), req_meta.shape, options);
599+
key_tensor.val = at::from_blob(
600+
Backend::Get()->GetDeviceAddrFromHostPtr(
601+
req_data.vals.data(),
602+
std::accumulate(std::begin(req_meta.shape),
603+
std::end(req_meta.shape),
604+
c10::elementSize(at::ScalarType(req_meta.dtype)),
605+
std::multiplies<uint64_t>())),
606+
req_meta.shape, options);
596607
}
597608
key_tensor.key = req_data.keys[0];
598609
return key_tensor;
Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
#include <string>
1616
#include <unordered_map>
1717
#include <utility>
18+
#include <functional>
1819

1920
#include "dmlc/logging.h"
20-
#include "ps/internal/env.h"
21+
#include "ps/env.h"
22+
#include "base.h"
2123

2224
namespace ps {
2325

@@ -26,7 +28,7 @@ enum { BACKEND_OK = 0, BACKEND_FAILED = -1 };
2628
/**
2729
* \brief Abstract Backend Class
2830
*/
29-
class Backend {
31+
class STEPMESH_API Backend {
3032
public:
3133
/**
3234
* \brief Set device index for current thread
@@ -88,6 +90,24 @@ class Backend {
8890
*/
8991
virtual int SyncEvent(void* event) = 0;
9092

93+
/**
94+
*\brief Get an address that is directly readable via the PCIe bus
95+
* @param devicePtr device physical address
96+
* @return an address that is directly readable via the PCIe bus
97+
*/
98+
// Default implementation: returns `devicePtr` unchanged and does not use
// `size`. Being virtual, a backend may override it.
// NOTE(review): `size` is presumably the byte length of the region starting
// at `devicePtr` (the tensor overload passes numel() * element_size()) -
// confirm against the plugin backends.
virtual void* GetAccessibleAddr(void* devicePtr, size_t size) {
99+
return devicePtr;
100+
}
101+
102+
// Tensor convenience overload: forwards tensor.data_ptr() together with
// numel() * element_size() to the (void*, size_t) overload and returns its
// result.
virtual void* GetAccessibleAddr(const at::Tensor& tensor) {
103+
return GetAccessibleAddr(tensor.data_ptr(),
104+
tensor.numel() * tensor.element_size());
105+
}
106+
107+
// Default implementation: returns `hostPtr` unchanged and does not use
// `size`. Being virtual, a backend may override it.
// NOTE(review): judging by the name, overrides translate a host-visible
// pointer back to a device address; the visible caller passes a byte count
// derived from the request shape and dtype as `size` - confirm.
virtual void* GetDeviceAddrFromHostPtr(void* hostPtr, size_t size) {
108+
return hostPtr;
109+
}
110+
91111
/**
92112
* \brief Get the backend implementation
93113
* @return the backend implementation
@@ -98,21 +118,28 @@ class Backend {
98118
RegisterImpl(name, backend);
99119
}
100120

121+
static void RegisterLazy(const std::string& name,
122+
const std::function<Backend*(void)>& ctor);
101123
protected:
102124
Backend() = default;
103125

104126
private:
105127
static std::mutex backends_mutex_;
106128
static std::unordered_map<std::string, Backend*> backends_;
129+
static std::unordered_map<std::string, std::function<Backend*(void)>> backend_ctors_;
107130

108131
static Backend* GetImpl() {
109132
static Backend* backend_impl = nullptr;
110133
if (backend_impl == nullptr) {
111134
std::unique_lock<std::mutex> lock(backends_mutex_);
112135
std::string backend_type = "GPU";
113-
backend_type = Environment::Get()->find("STEPMESH_BAKCEND", backend_type);
114-
PS_CHECK_NE(backends_.find(backend_type), backends_.end())
115-
<< "failed to get backend impl: " << backend_type;
136+
backend_type = Environment::Get()->find("STEPMESH_BACKEND", backend_type);
137+
if (backends_.find(backend_type) == backends_.end()) {
138+
PS_CHECK_NE(backend_ctors_.find(backend_type), backend_ctors_.end())
139+
<< "failed to get backend impl: " << backend_type;
140+
backends_[backend_type] = backend_ctors_[backend_type]();
141+
}
142+
116143
backend_impl = backends_[backend_type];
117144
}
118145
return backend_impl;

include/ps/internal/assign_op.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
*/
77
#ifndef PS_INTERNAL_ASSIGN_OP_H_
88
#define PS_INTERNAL_ASSIGN_OP_H_
9+
10+
#include <cassert>
11+
912
#include "ps/internal/utils.h"
1013
namespace ps {
1114

@@ -24,7 +27,7 @@ enum AssignOp {
2427
* \brief return an assignment function: right op= left
2528
*/
2629
template <typename T>
27-
inline void AssignFunc(const T& lhs, AssignOp op, T* rhs) {
30+
inline void AssignFunc(const T& left, AssignOp op, T* right) {
2831
switch (op) {
2932
case ASSIGN:
3033
*right = left;
@@ -42,7 +45,7 @@ inline void AssignFunc(const T& lhs, AssignOp op, T* rhs) {
4245
*right /= left;
4346
break;
4447
default:
45-
LOG(FATAL) << "use AssignOpInt..";
48+
PS_LOG(FATAL) << "use AssignOpInt..";
4649
}
4750
}
4851

@@ -51,7 +54,7 @@ inline void AssignFunc(const T& lhs, AssignOp op, T* rhs) {
5154
* works for integers
5255
*/
5356
template <typename T>
54-
inline void AssignFuncInt(const T& lhs, AssignOp op, T* rhs) {
57+
inline void AssignFuncInt(const T& left, AssignOp op, T* right) {
5558
switch (op) {
5659
case ASSIGN:
5760
*right = left;

include/ps/internal/cpu_backend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#include <memory>
99

10-
#include "ps/internal/backend.h"
10+
#include "ps/backend.h"
1111

1212
namespace ps {
1313

0 commit comments

Comments
 (0)