
Commit 47bcce5

bugfix: fix unsupported synchronizing streams. (jd-opensource#351)
Signed-off-by: Tao Peng <[email protected]>
1 parent 5d89633 commit 47bcce5

File tree

4 files changed (+8 / -4 lines)


xllm/core/platform/device.cpp

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ int64_t Device::free_memory() { return get_device_mem().free_memory; }
 
 int Device::synchronize_default_stream() {
 #if defined(USE_NPU)
-  c10_npu::getCurrentNPUStream(index()).synchronize();
+  return aclrtSynchronizeStream(c10_npu::getCurrentNPUStream(index()).stream());
 #elif defined(USE_MLU)
   torch_mlu::getCurrentMLUStream(index()).synchronize();
 #elif defined(USE_CUDA)
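
A minimal sketch of how a caller might consume the status code that the patched synchronize_default_stream() now surfaces on NPU builds. ACL_SUCCESS (0) from <acl/acl.h>, the glog-style logging, and the header/namespace names are assumptions for illustration, not code from this commit.

// Hedged sketch: propagate the aclrtSynchronizeStream() status returned by
// the patched Device::synchronize_default_stream(). ACL_SUCCESS and the
// logging macro are assumptions; adapt to the project's error handling.
#include <acl/acl.h>        // assumed: defines aclError and ACL_SUCCESS (0)
#include <glog/logging.h>   // assumed logging backend

#include "core/platform/device.h"  // assumed header path for xllm::Device

void wait_for_default_stream(xllm::Device& device) {
  const int ret = device.synchronize_default_stream();
  if (ret != ACL_SUCCESS) {
    LOG(ERROR) << "synchronize_default_stream failed, acl error code: " << ret;
  }
}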

xllm/core/platform/stream.cpp

Lines changed: 4 additions & 0 deletions
@@ -26,8 +26,12 @@ Stream::Stream() : stream_(c10::cuda::getStreamFromPool()) {}
 #endif
 
 int Stream::synchronize() const {
+#if defined(USE_NPU)
+  return aclrtSynchronizeStream(stream_.stream());
+#else
   stream_.unwrap().synchronize();
   return 0;
+#endif
 }
 
 c10::StreamGuard Stream::set_stream_guard() const {
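
For context, a hedged usage sketch of the Stream API touched here: launch work under the stream guard, then block on synchronize(), which after this change returns the raw ACL status on NPU and 0 otherwise. The header path, namespace, and tensor workload are illustrative assumptions.

// Hedged usage sketch (not repo code): run a small workload on a Stream and
// wait for it. Only the members visible in this diff are used.
#include <torch/torch.h>

#include "core/platform/stream.h"  // assumed header path

int run_and_wait(xllm::Stream& stream) {
  // Route subsequent kernel launches to this stream while the guard lives.
  c10::StreamGuard guard = stream.set_stream_guard();
  auto a = torch::ones({512, 512});
  auto b = torch::matmul(a, a);  // enqueued on the guarded stream
  // Returns aclrtSynchronizeStream()'s status on NPU builds, 0 elsewhere.
  return stream.synchronize();
}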

xllm/core/scheduler/continuous_scheduler.cpp

Lines changed: 2 additions & 1 deletion
@@ -960,7 +960,8 @@ void ContinuousScheduler::generate() {
   while (num_pending_requests() > 0 || !batch_empty ||
          request_queue_.size() > 0) {
     // build a batch of requests/sequences
-    auto batch = prepare_batch();
+    const auto timeout = absl::Milliseconds(500);
+    std::vector<Batch> batch = schedule_request(timeout);
     batch_empty = true;
     for (auto& b : batch) {
       batch_empty &= b.empty();
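
The loop now polls schedule_request() with a 500 ms absl::Duration and then checks whether every returned Batch is empty. Below is a small, self-contained sketch of that emptiness test; Batch is stood in by a template parameter since its definition is not part of this diff.

// Hedged sketch (illustrative only): the step is considered idle when every
// batch produced by schedule_request(timeout) is empty, mirroring the
// batch_empty accumulation in the loop above.
#include <vector>

template <typename BatchT>
bool all_batches_empty(const std::vector<BatchT>& batches) {
  bool batch_empty = true;
  for (const auto& b : batches) {
    batch_empty &= b.empty();  // stays true only while every batch is empty
  }
  return batch_empty;
}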

xllm/models/llm/llm_model_base.h

Lines changed: 1 addition & 2 deletions
@@ -96,7 +96,6 @@ class LlmDecoderLayerImplBase : public torch::nn::Module {
       int node_id,
       std::vector<aclrtEvent*> event,
       std::vector<std::atomic<bool>*> event_flag) {
-#if defined(USE_NPU)
     auto micro_batch_num = x.size();
     for (auto i = 0; i < micro_batch_num; ++i) {
       if (input_params[i].src_block_indices.numel() > 0) {
@@ -108,7 +107,7 @@ class LlmDecoderLayerImplBase : public torch::nn::Module {
           0);
       }
     }
-#endif
+
     return decoder_layer_(x,
                           cos_pos,
                           sin_pos,
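
With the USE_NPU guard removed, the micro-batch loop runs on every build, and each iteration still skips micro-batches that carry no source block indices. A hedged stand-alone sketch of that per-micro-batch check, using a plain torch::Tensor in place of the params struct:

// Hedged sketch (not repo code): the per-micro-batch check that gates the
// copy path. src_block_indices is modeled as a bare torch::Tensor here;
// in the repo it lives inside input_params[i].
#include <torch/torch.h>

bool needs_block_copy(const torch::Tensor& src_block_indices) {
  // An empty index tensor means nothing was scheduled for copying in this
  // micro-batch, so the copy can be skipped.
  return src_block_indices.numel() > 0;
}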
