Commit 11cd05a

feat: support beam search for llm model[1/N]. (jd-opensource#135)
1 parent 1c81533 commit 11cd05a

36 files changed: +604 -27 lines changed

xllm/core/framework/batch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ cc_library(
     batch_factory.cpp
     batch_input_builder.cpp
     mposition.cpp
+    beam_search.h
   DEPS
     :request
     :runtime

xllm/core/framework/batch/batch.cpp

Lines changed: 10 additions & 0 deletions

@@ -75,6 +75,7 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
       mm_data_vec_,
       copy_in_cache_block_infos_,
       copy_out_cache_block_infos_,
+      swap_cache_block_infos_,
       &args);
   return builder.build_forward_input(num_decoding_tokens,
                                      min_decoding_batch_size);
@@ -88,6 +89,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
       mm_data_vec_,
       copy_in_cache_block_infos_,
       copy_out_cache_block_infos_,
+      swap_cache_block_infos_,
       nullptr);
   return builder.build_raw_forward_input(start_idx, end_idx);
 }
@@ -134,6 +136,7 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output,
     }
   }
   CHECK_EQ(output_idx, num_seqs);
+  process_beam_search();
 }

 void Batch::process_sample_output(const SampleOutput& sample_output,
@@ -175,6 +178,7 @@ void Batch::process_sample_output(const SampleOutput& sample_output,
     append_token_for_sequence(seq, token, 0, enable_schedule_overlap);
   }
   CHECK_EQ(output_idx, num_seqs);
+  process_beam_search();
 }

 bool Batch::update_sequence_state(Sequence* seq, bool enable_schedule_overlap) {
@@ -246,4 +250,10 @@ void Batch::process_embedding_output(const torch::Tensor& output_embedding) {
     }
   }
 }
+
+void Batch::process_beam_search() {
+  for (auto* sequence_group : sequence_groups_) {
+    sequence_group->process_beam_search();
+  }
+}
 }  // namespace xllm
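
Note that process_beam_search() above only delegates to each SequencesGroup; the group-level logic lives outside this excerpt. As an orientation aid, the following standalone C++ sketch shows the standard per-step beam update such a method performs conceptually: expand each live beam with its sampled candidates, accumulate log-probabilities, and keep the best beam_width paths. All names and values here are illustrative and are not xllm's API.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

struct Beam {
  std::vector<int> tokens;
  float logprob_sum = 0.0f;  // cumulative log-probability of the path
};

int main() {
  const std::size_t beam_width = 2;

  // Two live beams and, for each, the (token, logprob) candidates sampled
  // for it at the current decode step.
  std::vector<Beam> beams = {{{5}, -0.5f}, {{7}, -0.8f}};
  std::vector<std::vector<std::pair<int, float>>> step = {
      {{11, -0.2f}, {12, -1.0f}},   // candidates extending beams[0]
      {{13, -0.1f}, {14, -0.9f}}};  // candidates extending beams[1]

  // Expand every beam with every candidate sampled for it.
  std::vector<Beam> expanded;
  for (std::size_t i = 0; i < beams.size(); ++i) {
    for (const auto& [token, logprob] : step[i]) {
      Beam b = beams[i];
      b.tokens.push_back(token);
      b.logprob_sum += logprob;
      expanded.push_back(std::move(b));
    }
  }

  // Keep only the beam_width best paths as the live beams for the next step.
  std::sort(expanded.begin(), expanded.end(),
            [](const Beam& a, const Beam& b) {
              return a.logprob_sum > b.logprob_sum;
            });
  if (expanded.size() > beam_width) expanded.resize(beam_width);

  for (const auto& b : expanded) {
    std::printf("beam ending in token %d, logprob_sum = %.2f\n",
                b.tokens.back(), b.logprob_sum);
  }
  return 0;
}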

xllm/core/framework/batch/batch.h

Lines changed: 15 additions & 0 deletions

@@ -22,7 +22,9 @@ limitations under the License.
 #include <vector>

 #include "framework/request/mm_data.h"
+#include "framework/request/request.h"
 #include "framework/request/sequence.h"
+#include "framework/request/sequences_group.h"
 #include "runtime/forward_params.h"

 namespace xllm {
@@ -41,6 +43,10 @@ class Batch {

   void add(const std::vector<Sequence*>& sequences);

+  void add(SequencesGroup* sequence_group) {
+    sequence_groups_.push_back(sequence_group);
+  };
+
   void set_copy_in_cache_block_infos(
       std::vector<CacheBlockInfo>* copy_in_cache_block_infos) {
     copy_in_cache_block_infos_ = copy_in_cache_block_infos;
@@ -51,6 +57,11 @@ class Batch {
     copy_out_cache_block_infos_ = copy_out_cache_block_infos;
   }

+  void set_swap_cache_block_infos(
+      std::vector<CacheBlockInfo>* swap_cache_block_infos) {
+    swap_cache_block_infos_ = swap_cache_block_infos;
+  }
+
   // get the number of sequences in the batch
   size_t size() const { return sequences_.size(); }
   bool empty() const { return sequences_.empty(); }
@@ -93,9 +104,13 @@ class Batch {
                                  int token_idx,
                                  bool enable_schedule_overlap);

+  void process_beam_search();
+
   std::vector<Sequence*> sequences_;
+  std::vector<SequencesGroup*> sequence_groups_;
   std::vector<CacheBlockInfo>* copy_in_cache_block_infos_ = nullptr;
   std::vector<CacheBlockInfo>* copy_out_cache_block_infos_ = nullptr;
+  std::vector<CacheBlockInfo>* swap_cache_block_infos_ = nullptr;

   // max number of tokens to process for each sequence
   // default to max value

xllm/core/framework/batch/batch_factory.cpp

Lines changed: 27 additions & 1 deletion

@@ -17,11 +17,25 @@ limitations under the License.

 namespace xllm {

+namespace {
+
+bool is_beam_search(const std::vector<std::shared_ptr<Request>>& requests) {
+  for (const auto& request : requests) {
+    if (request->check_beam_search()) {
+      return true;
+    }
+  }
+  return false;
+}
+}  // namespace
+
 std::vector<Batch> BatchFactory::create_batches(
+    const std::vector<std::shared_ptr<Request>>& running_requests,
     const std::vector<Sequence*>& running_sequences,
     const std::vector<size_t>& running_sequences_budgets,
     std::vector<std::vector<CacheBlockInfo>>* copy_in_cache_block_infos,
-    std::vector<std::vector<CacheBlockInfo>>* copy_out_cache_block_infos) {
+    std::vector<std::vector<CacheBlockInfo>>* copy_out_cache_block_infos,
+    std::vector<std::vector<CacheBlockInfo>>* swap_cache_block_infos) {
   size_t num_prompt_tokens = 0;
   size_t num_generated_tokens = 0;
   std::vector<Batch> batches(dp_size_);
@@ -50,6 +64,14 @@ std::vector<Batch> BatchFactory::create_batches(
     }
   }

+  if (is_beam_search(running_requests)) {
+    for (const auto& request : running_requests) {
+      auto seq_group = request->sequence_group();
+      int32_t dp_rank = seq_group->dp_rank();
+      batches[dp_rank].add(seq_group);
+    }
+  }
+
   for (int i = 0; i < dp_size_; i++) {
     if (!batches[i].empty()) {
       if (copy_in_cache_block_infos != nullptr &&
@@ -62,6 +84,10 @@ std::vector<Batch> BatchFactory::create_batches(
         batches[i].set_copy_out_cache_block_infos(
             &(copy_out_cache_block_infos->at(i)));
       }
+      if (swap_cache_block_infos != nullptr &&
+          swap_cache_block_infos->size() == dp_size_) {
+        batches[i].set_swap_cache_block_infos(&(swap_cache_block_infos->at(i)));
+      }
     }
   }

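
The new block above routes each request's sequence group to the batch owned by its data-parallel rank. A minimal standalone sketch of that routing pattern, using generic stand-in types rather than xllm's Request/SequencesGroup:

#include <cstdio>
#include <vector>

// Stand-in for a sequence group, reduced to what the routing needs:
// every group knows which data-parallel rank owns its KV cache.
struct Group {
  int dp_rank;
  int request_id;
};

int main() {
  const int dp_size = 2;
  std::vector<std::vector<Group*>> batches(dp_size);

  std::vector<Group> groups = {{0, 100}, {1, 101}, {0, 102}};
  for (auto& g : groups) {
    batches[g.dp_rank].push_back(&g);  // each group joins its rank's batch
  }

  for (int i = 0; i < dp_size; ++i) {
    std::printf("dp rank %d: %zu group(s)\n", i, batches[i].size());
  }
  return 0;
}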

xllm/core/framework/batch/batch_factory.h

Lines changed: 3 additions & 0 deletions

@@ -28,11 +28,14 @@ class BatchFactory {
   }

   std::vector<Batch> create_batches(
+      const std::vector<std::shared_ptr<Request>>& running_requests,
       const std::vector<Sequence*>& running_sequences,
       const std::vector<size_t>& running_sequences_budgets,
       std::vector<std::vector<CacheBlockInfo>>* copy_in_cache_block_infos =
           nullptr,
       std::vector<std::vector<CacheBlockInfo>>* copy_out_cache_block_infos =
+          nullptr,
+      std::vector<std::vector<CacheBlockInfo>>* swap_cache_block_infos =
           nullptr);

  private:

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 16 additions & 1 deletion

@@ -57,6 +57,7 @@ BatchInputBuilder::BatchInputBuilder(
     const std::vector<MMData>& mm_data_vec,
     const std::vector<CacheBlockInfo>* copy_in_cache_block_infos,
     const std::vector<CacheBlockInfo>* copy_out_cache_block_infos,
+    const std::vector<CacheBlockInfo>* swap_cache_block_infos,
     const ModelArgs* args)
     : sequences_(sequences),
       allowed_max_tokens_(allowed_max_tokens),
@@ -65,7 +66,8 @@ BatchInputBuilder::BatchInputBuilder(
       args_(args),
       num_sequences_(static_cast<int32_t>(sequences.size())),
       copy_in_cache_block_infos_(copy_in_cache_block_infos),
-      copy_out_cache_block_infos_(copy_out_cache_block_infos) {
+      copy_out_cache_block_infos_(copy_out_cache_block_infos),
+      swap_cache_block_infos_(swap_cache_block_infos) {
   // Reserve space for better performance
   state_.flatten_tokens_vec.reserve(1000);
   state_.flatten_positions_vec.reserve(1000);
@@ -348,6 +350,13 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
     input_params.input_embedding = torch::cat(input_embeddings_vec_);
   }

+  if (swap_cache_block_infos_ != nullptr &&
+      swap_cache_block_infos_->size() > 0) {
+    input_params.swap_blocks.insert(input_params.swap_blocks.end(),
+                                    swap_cache_block_infos_->begin(),
+                                    swap_cache_block_infos_->end());
+  }
+
   CHECK_EQ(state_.sampling_params.size(), state_.selected_token_idxes.size());
   // Setup sampling parameters
   if (!state_.selected_token_idxes.empty()) {
@@ -427,6 +436,12 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
                                      copy_in_cache_block_infos_->begin(),
                                      copy_in_cache_block_infos_->end());
   }
+  if (swap_cache_block_infos_ != nullptr &&
+      swap_cache_block_infos_->size() > 0) {
+    raw_forward_input.swap_blocks.insert(raw_forward_input.swap_blocks.end(),
+                                         swap_cache_block_infos_->begin(),
+                                         swap_cache_block_infos_->end());
+  }

   split_copy_out_blocks(raw_forward_input, write_block_ids_);


xllm/core/framework/batch/batch_input_builder.h

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,7 @@ class BatchInputBuilder {
       const std::vector<MMData>& mm_data_vec,
       const std::vector<CacheBlockInfo>* copy_in_cache_block_infos,
       const std::vector<CacheBlockInfo>* copy_out_cache_block_infos,
+      const std::vector<CacheBlockInfo>* swap_cache_block_infos,
       const ModelArgs* args);

   ForwardInput build_forward_input(uint32_t num_decoding_tokens,
@@ -125,6 +126,7 @@ class BatchInputBuilder {
   std::unordered_set<int32_t> write_block_ids_;
   const std::vector<CacheBlockInfo>* copy_in_cache_block_infos_ = nullptr;
   const std::vector<CacheBlockInfo>* copy_out_cache_block_infos_ = nullptr;
+  const std::vector<CacheBlockInfo>* swap_cache_block_infos_ = nullptr;
 };

 }  // namespace xllm
xllm/core/framework/batch/beam_search.h

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+namespace xllm {
+
+// BeamCandidate structure for beam search sorting
+struct BeamCandidate {
+  size_t seq_index;
+  float logprob_sum;
+  std::vector<int32_t> token_ids;
+  std::vector<std::optional<float>> logprobs;
+
+  BeamCandidate() = default;
+
+  BeamCandidate(size_t seq_idx,
+                float logprob,
+                std::vector<int32_t>& token_ids,
+                std::vector<std::optional<float>>& logprobs)
+      : seq_index(seq_idx),
+        logprob_sum(logprob),
+        token_ids(std::move(token_ids)),
+        logprobs(std::move(logprobs)) {}
+
+  bool operator<(const BeamCandidate& other) const {
+    return logprob_sum > other.logprob_sum;
+  }
+};
+
+template <typename CandidateType>
+class SimpleTopKOptimizer {
+ private:
+  std::priority_queue<CandidateType> min_heap_;
+  size_t k_;
+
+ public:
+  explicit SimpleTopKOptimizer(size_t k) : k_(k) {}
+
+  void clear() {
+    while (!min_heap_.empty()) {
+      min_heap_.pop();
+    }
+  }
+
+  void insert(const CandidateType& candidate) {
+    if (min_heap_.size() < k_) {
+      min_heap_.push(candidate);
+    } else if (candidate.logprob_sum > min_heap_.top().logprob_sum) {
+      min_heap_.pop();
+      min_heap_.push(candidate);
+    }
+  }
+
+  void insert(CandidateType&& candidate) {
+    if (min_heap_.size() < k_) {
+      min_heap_.push(std::move(candidate));
+    } else if (candidate.logprob_sum > min_heap_.top().logprob_sum) {
+      min_heap_.pop();
+      min_heap_.push(std::move(candidate));
+    }
+  }
+
+  void insert_batch(const std::vector<CandidateType>& candidates) {
+    for (const auto& candidate : candidates) {
+      insert(candidate);
+    }
+  }
+
+  std::vector<CandidateType> getTopK() {
+    std::vector<CandidateType> result;
+    result.reserve(min_heap_.size());
+
+    while (!min_heap_.empty()) {
+      result.emplace_back(
+          std::move(const_cast<CandidateType&>(min_heap_.top())));
+      min_heap_.pop();
+    }
+
+    return result;
+  }
+
+  std::vector<CandidateType>&& getTopKMove() {
+    std::vector<CandidateType> result;
+    result.reserve(min_heap_.size());
+
+    while (!min_heap_.empty()) {
+      result.emplace_back(
+          std::move(const_cast<CandidateType&>(min_heap_.top())));
+      min_heap_.pop();
+    }
+
+    return std::move(result);
+  }
+
+  std::vector<CandidateType> getTopKSorted() {
+    std::vector<CandidateType> result = getTopK();
+    std::reverse(result.begin(), result.end());
+    return result;
+  }
+
+  size_t size() const { return min_heap_.size(); }
+
+  bool empty() const { return min_heap_.empty(); }
+
+  bool worthInserting(float logprob_sum) const {
+    return min_heap_.size() < k_ || logprob_sum > min_heap_.top().logprob_sum;
+  }
+
+  float getMinLogprob() const {
+    return min_heap_.empty() ? -std::numeric_limits<float>::infinity()
+                             : min_heap_.top().logprob_sum;
+  }
+};
+
+using SimpleTopKOptimizerBeamCandidate = SimpleTopKOptimizer<BeamCandidate>;
+
+}  // namespace xllm
