Skip to content

Commit e4916ec

Browse files
authored
feat: set the same random seed for all workers. (jd-opensource#483)
1 parent 32f3e01 commit e4916ec

22 files changed

+114
-48
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ DEFINE_bool(
424424
"The default prefetching ratio for gateup weight is 40%."
425425
"If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");
426426

427-
// rec prefill-only mode
427+
// --- rec prefill-only mode ---
428428
DEFINE_bool(enable_rec_prefill_only,
429429
false,
430430
"Enable rec prefill-only mode (no decoder self-attention blocks "
@@ -438,6 +438,9 @@ DEFINE_bool(
438438
"Whether to enable dp load balance, if true, sequences within a single "
439439
"dp batch will be shuffled.");
440440

441+
// --- the seed for random number generator ---
442+
DEFINE_int32(random_seed, -1, "Random seed for random number generator.");
443+
441444
// --- dit cache config ---
442445

443446
DEFINE_string(dit_cache_policy,

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ DECLARE_int32(flashinfer_workspace_buffer_size);
215215

216216
DECLARE_bool(enable_dp_balance);
217217

218+
DECLARE_int32(random_seed);
219+
218220
DECLARE_string(dit_cache_policy);
219221

220222
DECLARE_int64(dit_cache_warmup_steps);

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,12 @@ bool CommChannel::unlink_cluster(const std::vector<uint64_t>& cluster_ids,
211211
return true;
212212
}
213213

214-
bool CommChannel::init_model(const std::string& model_weights_path) {
215-
proto::ModelPath request;
214+
bool CommChannel::init_model(const std::string& model_weights_path,
215+
int32_t random_seed) {
216+
proto::InitModelRequest request;
216217

217218
request.set_model_weights_path(model_weights_path);
219+
request.set_random_seed(random_seed);
218220
proto::Status response;
219221
brpc::Controller cntl;
220222
stub_->InitModel(&cntl, &request, &response, nullptr);
@@ -226,10 +228,12 @@ bool CommChannel::init_model(const std::string& model_weights_path) {
226228
}
227229

228230
bool CommChannel::init_model_async(const std::string& model_weights_path,
231+
int32_t random_seed,
229232
folly::Promise<bool>& promise) {
230-
proto::ModelPath request;
233+
proto::InitModelRequest request;
231234

232235
request.set_model_weights_path(model_weights_path);
236+
request.set_random_seed(random_seed);
233237
auto done = new InitModelClosure();
234238
done->promise = std::move(promise);
235239
stub_->InitModel(&done->cntl, &request, &done->response, done);

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ class CommChannel {
6262
const std::vector<std::string>& device_ips,
6363
const std::vector<uint16_t>& ports);
6464

65-
virtual bool init_model(const std::string& model_weights_path);
65+
virtual bool init_model(const std::string& model_weights_path,
66+
int32_t random_seed);
6667

6768
virtual bool init_model_async(const std::string& model_weights_path,
69+
int32_t random_seed,
6870
folly::Promise<bool>& promise);
6971

7072
virtual bool estimate_kv_cache_capacity(int64_t& available_memory,

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ bool RemoteWorker::unlink_cluster(const std::vector<uint64_t>& cluster_ids,
104104
return channel_->unlink_cluster(cluster_ids, addrs, device_ips, ports);
105105
}
106106

107-
bool RemoteWorker::init_model(const std::string& model_weights_path) {
108-
return channel_->init_model(model_weights_path);
107+
bool RemoteWorker::init_model(const std::string& model_weights_path,
108+
int32_t random_seed) {
109+
return channel_->init_model(model_weights_path, random_seed);
109110
}
110111

111112
std::tuple<int64_t, int64_t> RemoteWorker::estimate_kv_cache_capacity() {
@@ -190,14 +191,17 @@ folly::SemiFuture<folly::Unit> RemoteWorker::process_group_test_async() {
190191
}
191192

192193
folly::SemiFuture<bool> RemoteWorker::init_model_async(
193-
const std::string& model_weights_path) {
194+
const std::string& model_weights_path,
195+
int32_t random_seed) {
194196
folly::Promise<bool> promise;
195197
auto future = promise.getSemiFuture();
196-
threadpool_.schedule(
197-
[this, model_weights_path, promise = std::move(promise)]() mutable {
198-
// call InitModel with callback
199-
channel_->init_model_async(model_weights_path, promise);
200-
});
198+
threadpool_.schedule([this,
199+
model_weights_path,
200+
random_seed,
201+
promise = std::move(promise)]() mutable {
202+
// call InitModel with callback
203+
channel_->init_model_async(model_weights_path, random_seed, promise);
204+
});
201205
return future;
202206
}
203207

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class RemoteWorker : public WorkerClient {
4646

4747
bool wait_for_server_ready(const std::string& server_address);
4848

49-
virtual bool init_model(const std::string& model_weights_path) override;
49+
virtual bool init_model(const std::string& model_weights_path,
50+
int32_t random_seed) override;
5051

5152
virtual std::tuple<int64_t, int64_t> estimate_kv_cache_capacity() override;
5253

@@ -87,7 +88,8 @@ class RemoteWorker : public WorkerClient {
8788
const ForwardInput& inputs) override;
8889

8990
virtual folly::SemiFuture<bool> init_model_async(
90-
const std::string& model_weights_path) override;
91+
const std::string& model_weights_path,
92+
int32_t random_seed) override;
9193

9294
virtual folly::SemiFuture<std::tuple<int64_t, int64_t>>
9395
estimate_kv_cache_capacity_async() override;

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,15 @@ void WorkerService::Hello(::google::protobuf::RpcController* controller,
220220
}
221221

222222
void WorkerService::InitModel(::google::protobuf::RpcController* controller,
223-
const proto::ModelPath* request,
223+
const proto::InitModelRequest* request,
224224
proto::Status* response,
225225
::google::protobuf::Closure* done) {
226226
threadpool_->schedule([this, controller, request, response, done]() mutable {
227227
brpc::ClosureGuard done_guard(done);
228228
auto model_weights_path = request->model_weights_path();
229-
auto init_future = worker_->init_model_async(model_weights_path);
229+
auto random_seed = request->random_seed();
230+
auto init_future =
231+
worker_->init_model_async(model_weights_path, random_seed);
230232
bool status = std::move(init_future).get();
231233
if (!status) {
232234
response->set_ok(false);

xllm/core/distributed_runtime/worker_service.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class WorkerService : public proto::DistributeWorker {
4545
::google::protobuf::Closure* done) override;
4646

4747
void InitModel(::google::protobuf::RpcController* controller,
48-
const proto::ModelPath* request,
48+
const proto::InitModelRequest* request,
4949
proto::Status* response,
5050
::google::protobuf::Closure* done) override;
5151

xllm/core/platform/device.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "device.h"
17-
#if defined(USE_MLU)
17+
#if defined(USE_NPU)
18+
#include <torch_npu/csrc/aten/NPUGeneratorImpl.h>
19+
#elif defined(USE_MLU)
1820
#include <cn_api.h>
1921
#include <torch_mlu/csrc/framework/core/device.h>
2022
#include <torch_mlu/csrc/framework/core/device_utils.h>
23+
#include <torch_mlu/csrc/framework/generator/generator_impl.h>
2124
#elif defined(USE_CUDA)
2225
#include <c10/cuda/CUDAStream.h>
2326
#include <cuda.h>
@@ -39,6 +42,22 @@ void Device::set_device() const {
3942
#endif
4043
}
4144

45+
// Seeds the random number generators associated with this device.
//
// torch::manual_seed(seed) is always called first. The accelerator generator
// is then seeded according to the backend selected at compile time:
//   - USE_NPU / USE_MLU: the default generator for index() is reseeded while
//     holding that generator's mutex.
//   - USE_CUDA: torch::cuda::manual_seed(seed) is called.
//     NOTE(review): that call presumably targets the *current* CUDA device
//     rather than index() - confirm set_device() is invoked beforehand.
// When no backend macro is defined only torch::manual_seed runs.
//
// @param seed  value passed unchanged to every generator that is seeded.
void Device::set_seed(uint64_t seed) const {
  torch::manual_seed(seed);
#if defined(USE_NPU)
  auto gen = at_npu::detail::getDefaultNPUGenerator(index());
  {
    // at::Generator state is not thread safe; take its mutex before mutating
    // the seed, exactly as the MLU branch below does.
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
#elif defined(USE_MLU)
  auto gen = torch_mlu::getDefaultMLUGenerator(index());
  {
    // Hold the generator mutex for the duration of the reseed.
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
#elif defined(USE_CUDA)
  torch::cuda::manual_seed(seed);
#endif
}
60+
4261
const torch::Device& Device::unwrap() const { return device_; }
4362

4463
int32_t Device::index() const { return device_.index(); }

xllm/core/platform/device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class Device {
3333

3434
void set_device() const;
3535

36+
void set_seed(uint64_t seed = 42) const;
37+
3638
const torch::Device& unwrap() const;
3739
int32_t index() const;
3840

0 commit comments

Comments
 (0)