1+ /* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ ==============================================================================*/
15+
#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
17+
18+ namespace xllm {
19+
// One beam-search candidate. Candidates are ordered by `logprob_sum` only,
// with the comparison inverted so that std::priority_queue behaves as a
// min-heap on the score (see operator< below).
struct BeamCandidate {
  // Index identifying the candidate's sequence.
  // NOTE(review): presumably the index of the source sequence in the beam;
  // confirm against the caller.
  size_t seq_index = 0;
  // Score used for ordering candidates.
  float logprob_sum = 0.0f;
  // Token ids carried by this candidate.
  std::vector<int32_t> token_ids;
  // Per-entry log-probabilities; an entry may be absent (std::nullopt).
  std::vector<std::optional<float>> logprobs;

  // Produces an empty candidate with seq_index == 0 and logprob_sum == 0, so
  // that a default-constructed object is always safe to read and compare.
  BeamCandidate() = default;

  // Builds a candidate and takes over the contents of `token_ids` and
  // `logprobs`: both arguments are moved from, so the caller's vectors are
  // left in a valid but unspecified state after this call.
  BeamCandidate(size_t seq_idx,
                float logprob,
                std::vector<int32_t>& token_ids,
                std::vector<std::optional<float>>& logprobs)
      : seq_index(seq_idx),
        logprob_sum(logprob),
        token_ids(std::move(token_ids)),
        logprobs(std::move(logprobs)) {}

  // Intentionally inverted: a candidate with a HIGHER logprob_sum compares as
  // "less". std::priority_queue is a max-heap under operator<, so with this
  // ordering its top() is the candidate with the LOWEST logprob_sum.
  bool operator<(const BeamCandidate& other) const {
    return logprob_sum > other.logprob_sum;
  }
};
42+
// Keeps the (at most) k candidates with the largest `logprob_sum` seen so far.
//
// Requirements on CandidateType:
//   * a public `logprob_sum` member that can be compared with `>`;
//   * an operator< that is inverted with respect to `logprob_sum`
//     (a < b when a.logprob_sum > b.logprob_sum). std::priority_queue is a
//     max-heap under operator<, so with that ordering top() is the retained
//     candidate with the LOWEST score, i.e. the next one to be evicted.
//
// A capacity of k == 0 is allowed: such an instance never stores anything.
template <typename CandidateType>
class SimpleTopKOptimizer {
 private:
  std::priority_queue<CandidateType> min_heap_;
  size_t k_;

  // Single admission rule shared by insert() and worthInserting(): accept
  // while there is spare capacity, otherwise only a score strictly greater
  // than the current worst retained one. The k_ == 0 check comes first so
  // that top() is never evaluated on an empty heap.
  template <typename Score>
  bool admits(const Score& score) const {
    if (k_ == 0) {
      return false;
    }
    return min_heap_.size() < k_ || score > min_heap_.top().logprob_sum;
  }

  // Drops the current worst candidate when the heap is already at capacity,
  // making room for a candidate that admits() has just accepted.
  void evict_if_full() {
    if (min_heap_.size() >= k_) {
      min_heap_.pop();
    }
  }

 public:
  // `k` is the maximum number of candidates retained.
  explicit SimpleTopKOptimizer(size_t k) : k_(k) {}

  // Removes every retained candidate; the capacity k is unchanged.
  void clear() { min_heap_ = std::priority_queue<CandidateType>(); }

  // Offers a copy of `candidate`. It is stored if there is spare capacity or
  // if its score is strictly greater than the current worst retained score
  // (which is then evicted); otherwise it is ignored.
  void insert(const CandidateType& candidate) {
    if (!admits(candidate.logprob_sum)) {
      return;
    }
    evict_if_full();
    min_heap_.push(candidate);
  }

  // Same as above, but moves from `candidate` when it is accepted. A rejected
  // candidate is left untouched.
  void insert(CandidateType&& candidate) {
    if (!admits(candidate.logprob_sum)) {
      return;
    }
    evict_if_full();
    min_heap_.push(std::move(candidate));
  }

  // Offers every element of `candidates` in order, by copy.
  void insert_batch(const std::vector<CandidateType>& candidates) {
    for (const auto& candidate : candidates) {
      insert(candidate);
    }
  }

  // Moves all retained candidates out, in ascending logprob_sum order (worst
  // first), and leaves the optimizer empty.
  std::vector<CandidateType> getTopK() {
    std::vector<CandidateType> result;
    result.reserve(min_heap_.size());

    while (!min_heap_.empty()) {
      // priority_queue only exposes a const top(). The element is removed on
      // the next line, so it is moved out instead of copied. pop() may still
      // compare the moved-from element, so CandidateType's operator< must
      // depend only on state that a move leaves intact.
      result.emplace_back(
          std::move(const_cast<CandidateType&>(min_heap_.top())));
      min_heap_.pop();
    }

    return result;
  }

  // Equivalent to getTopK(). Returns by value: the result is a local vector,
  // so handing out a reference to it would dangle. Callers still receive the
  // vector without a copy (move / copy elision).
  std::vector<CandidateType> getTopKMove() { return getTopK(); }

  // Like getTopK(), but in descending logprob_sum order (best first).
  std::vector<CandidateType> getTopKSorted() {
    std::vector<CandidateType> result = getTopK();
    std::reverse(result.begin(), result.end());
    return result;
  }

  // Number of candidates currently retained (never more than k).
  size_t size() const { return min_heap_.size(); }

  // True when no candidate is retained.
  bool empty() const { return min_heap_.empty(); }

  // Cheap pre-check: would a candidate with this score be stored by insert()?
  // Always false when k == 0.
  bool worthInserting(float logprob_sum) const { return admits(logprob_sum); }

  // Lowest retained score, or -infinity when nothing is retained.
  float getMinLogprob() const {
    return min_heap_.empty() ? -std::numeric_limits<float>::infinity()
                             : min_heap_.top().logprob_sum;
  }
};
127+
// Convenience alias: the top-k selector instantiated for BeamCandidate.
using SimpleTopKOptimizerBeamCandidate = SimpleTopKOptimizer<BeamCandidate>;
129+
130+ } // namespace xllm
0 commit comments