/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "rec_constrained_decoding.h"

#include <c10/core/TensorOptions.h>
#include <folly/Unit.h>
#include <folly/futures/Future.h>
#include <glog/logging.h>

#include <algorithm>
#include <filesystem>
#include <fstream>
#include <future>
#include <mutex>

#include "common/global_flags.h"
#include "common/version_singleton.h"
#include "framework/state_dict/rec_vocab_dict.h"
#include "util/slice.h"
#include "util/tensor_helper.h"

namespace xllm {
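// Creates the constrained decoder for a specific recommendation model version
// and, when requested, a thread pool of GEN_MASK_THREAD_NUM workers used to
// parallelize mask generation.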
RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version,
                                               const int32_t vocab_size,
                                               torch::ScalarType dtype,
                                               torch::Device device,
                                               bool use_gen_threadpool)
    : use_gen_threadpool_(use_gen_threadpool),
      vocab_size_(vocab_size),
      model_version_(model_version),
      device_(device),
      dtype_(dtype) {
  if (use_gen_threadpool_) {
    gen_threadpool_ = std::make_unique<ThreadPool>(GEN_MASK_THREAD_NUM);
  }

  build_mask_cache_ = false;
}

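// Precomputes the mask for the first generated token: every vocab entry starts
// at PRE_MASK_FACTOR, and the token ids reachable from an empty prefix
// (queried from the versioned RecVocabDict) are reset to 0. The result is
// moved to the target device and cached.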
bool RecConstrainedDecoding::build_mask_cache() {
  first_token_mask_ = torch::full({vocab_size_}, PRE_MASK_FACTOR, dtype_);

  std::vector<int32_t> empty_token_ids;
  Slice<int32_t> prefix_token_ids = {empty_token_ids.data(),
                                     empty_token_ids.size()};

  const std::set<int32_t>& first_token_ids =
      VersionSingleton<RecVocabDict>::GetInstance(
          std::to_string(model_version_))
          ->get_next_tokens_by_prefix_tokens(prefix_token_ids);

  for (auto token_id : first_token_ids) {
    first_token_mask_[token_id] = 0;
  }

  first_token_mask_ = safe_to(first_token_mask_, device_, true);

  build_mask_cache_ = true;

  LOG(INFO) << "Build mask cache, first token ids size:"
            << first_token_ids.size();

  return true;
}

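// Returns an empty tensor if the cache has not been built or the batch is
// empty. For the first decoding step (no tokens generated yet) the cached
// first-token mask is broadcast across all sequences; later steps compute a
// per-sequence mask from each generated prefix.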
torch::Tensor RecConstrainedDecoding::generate_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  if (!build_mask_cache_ || 0 == generated_token_list.size()) {
    return torch::Tensor();
  }

  size_t token_size = generated_token_list[0].size();

  // Generate mask for first token
  if (0 == token_size) {
    size_t sequence_num = generated_token_list.size();
    auto mask = first_token_mask_.unsqueeze(0);
    return mask.repeat({sequence_num, 1});
  }

  // Generate mask for non-first token
  return generate_decode_mask(generated_token_list);
}

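// Builds a [sequence_num, vocab_size] mask filled with PRE_MASK_FACTOR and
// zeroes the entries for tokens the versioned RecVocabDict allows after each
// sequence's prefix. Lookups may run on the thread pool; the collected
// (sequence, vocab) index pairs are scattered into the mask at the end.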
torch::Tensor RecConstrainedDecoding::generate_decode_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  size_t sequence_num = generated_token_list.size();
  torch::TensorOptions options = torch::dtype(dtype_).device(device_);
  auto mask =
      torch::full({sequence_num, vocab_size_}, PRE_MASK_FACTOR, options);

  std::mutex global_batch_mutex;
  std::vector<int64_t> global_batch_token_indices;
  std::vector<int64_t> global_batch_vocab_indices;

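  // Reserve capacity up front; 8192 appears to be a heuristic per-sequence
  // upper bound on allowed next tokens, not a hard limit (exceeding it only
  // triggers reallocation).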
  int max_index_num_per_token = 8192;
  global_batch_token_indices.reserve(max_index_num_per_token * sequence_num);
  global_batch_vocab_indices.reserve(max_index_num_per_token * sequence_num);

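  // Collects, for each sequence in [start_idx, end_idx), the (sequence, vocab)
  // index pairs of tokens allowed after its generated prefix, then appends
  // them to the global index lists under the mutex.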
  auto update_mask = [&](size_t start_idx, size_t end_idx) {
    std::vector<int64_t> local_token_indices;
    std::vector<int64_t> local_vocab_indices;
    local_token_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));
    local_vocab_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));

    for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) {
      Slice<int32_t> tokens_slice(generated_token_list[token_idx]);

      const std::set<int32_t>& next_token_ids =
          VersionSingleton<RecVocabDict>::GetInstance(
              std::to_string(model_version_))
              ->get_next_tokens_by_prefix_tokens(tokens_slice);

      if (next_token_ids.size() > 0) {
        for (int32_t vocab_idx : next_token_ids) {
          local_token_indices.push_back(static_cast<int64_t>(token_idx));
          local_vocab_indices.push_back(static_cast<int64_t>(vocab_idx));
        }
      } else {
        LOG(ERROR) << "Fail to generate mask for tokens:"
                   << generated_token_list[token_idx];
      }
    }

    // Merge local results to global batch (thread-safe)
    if (!local_token_indices.empty()) {
      std::lock_guard<std::mutex> lock(global_batch_mutex);
      global_batch_token_indices.insert(global_batch_token_indices.end(),
                                        local_token_indices.begin(),
                                        local_token_indices.end());
      global_batch_vocab_indices.insert(global_batch_vocab_indices.end(),
                                        local_vocab_indices.begin(),
                                        local_vocab_indices.end());
    }
  };

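  // Split the sequences into at most GEN_MASK_THREAD_NUM batches and process
  // them on the thread pool, waiting on one promise/future pair per batch;
  // without the pool, process everything on the calling thread.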
  if (use_gen_threadpool_) {
    const size_t batch_size = std::max(
        1UL, (sequence_num + GEN_MASK_THREAD_NUM - 1) / GEN_MASK_THREAD_NUM);
    const size_t num_batches = (sequence_num + batch_size - 1) / batch_size;

    std::vector<std::future<void>> futures;
    std::vector<std::shared_ptr<std::promise<void>>> promises;

    promises.reserve(num_batches);
    futures.reserve(num_batches);

    for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
      auto promise = std::make_shared<std::promise<void>>();
      futures.push_back(promise->get_future());
      promises.push_back(promise);

      size_t start_idx = batch_idx * batch_size;
      size_t end_idx = std::min(start_idx + batch_size, sequence_num);

      gen_threadpool_->schedule(
          [update_mask, start_idx, end_idx, promise]() mutable {
            update_mask(start_idx, end_idx);
            promise->set_value();
          });
    }

    for (auto& future : futures) {
      future.get();
    }
  } else {
    update_mask(0, sequence_num);
  }

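  // Scatter zeros into all allowed (sequence, vocab) positions in a single
  // index_put_; every other position keeps PRE_MASK_FACTOR.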
  if (!global_batch_token_indices.empty()) {
    auto token_indices =
        torch::tensor(global_batch_token_indices, torch::kInt64);
    auto vocab_indices =
        torch::tensor(global_batch_vocab_indices, torch::kInt64);
    token_indices = safe_to(token_indices, device_, true);
    vocab_indices = safe_to(vocab_indices, device_, true);
    mask.index_put_({token_indices, vocab_indices}, 0.0f);
  }

  return mask;
}
}  // namespace xllm