Commit 32f3e01

feat: add constrained decoding for generative recommendation (jd-opensource#480)
1 parent 5f60bb0 commit 32f3e01

File tree: 4 files changed, +304 lines added, 0 removed

xllm/core/framework/sampling/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -10,12 +10,14 @@ cc_library(
     rejection_sampler.h
     sampler.h
     beam_searcher.h
+    rec_constrained_decoding.h
   SRCS
     sampling_params.cpp
     logits_utils.cpp
     rejection_sampler.cpp
     sampler.cpp
     beam_searcher.cpp
+    rec_constrained_decoding.cpp
   DEPS
     :common
     glog::glog
xllm/core/framework/sampling/constrained_decoding.h

Lines changed: 47 additions & 0 deletions

/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once
#include <c10/core/TensorOptions.h>
#include <torch/torch.h>
#include <torch/types.h>

namespace xllm {

// Constrained decoding is used to ensure that the generated content
// conforms to specific formats or rules.
class ConstrainedDecoding {
 public:
  virtual ~ConstrainedDecoding() = default;

  // Precompute and cache fixed constraint masks (e.g., static vocabulary
  // whitelists) to avoid redundant calculations during token generation.
  // Returns true if the cache was built successfully, false otherwise.
  virtual bool build_mask_cache() = 0;

  // Generate a dynamic constraint mask based on the already generated token
  // sequences. The mask is applied to filter invalid tokens.
  //
  // Input: generated_token_list - 2D vector of token IDs, where each inner
  // vector holds the generated tokens of one sequence in the batch
  // (format: [sequence_num][token_ids]).
  // Output: tensor of shape [sequence_num, vocab_size], where 0.0f marks
  // allowed tokens and a large negative value marks forbidden tokens for
  // each sequence; invalid tokens are filtered by adding the mask to the
  // model logits.
  virtual torch::Tensor generate_mask(
      const std::vector<std::vector<int32_t>>& generated_token_list) = 0;
};
}  // namespace xllm
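Since the mask is additive, applying it is a single tensor add on the logits before sampling. The helper below is a minimal sketch of such a call site, not part of this commit; apply_constraint_mask and its surroundings are illustrative names only.

// Minimal sketch (not from this commit) of applying the additive mask.
// `apply_constraint_mask` is a hypothetical helper for illustration.
#include <cstdint>
#include <vector>

#include <torch/torch.h>

#include "constrained_decoding.h"

torch::Tensor apply_constraint_mask(
    xllm::ConstrainedDecoding& decoder,
    const std::vector<std::vector<int32_t>>& generated_tokens,
    torch::Tensor logits) {  // shape: [sequence_num, vocab_size]
  torch::Tensor mask = decoder.generate_mask(generated_tokens);
  if (mask.defined()) {
    // 0.0f leaves allowed logits unchanged; the large negative entries
    // drive forbidden tokens' probabilities to ~0 after softmax.
    logits = logits + mask;
  }
  return logits;
}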
xllm/core/framework/sampling/rec_constrained_decoding.cpp

Lines changed: 197 additions & 0 deletions

/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "rec_constrained_decoding.h"

#include <c10/core/TensorOptions.h>
#include <folly/Unit.h>
#include <folly/futures/Future.h>
#include <glog/logging.h>

#include <algorithm>
#include <filesystem>
#include <fstream>
#include <future>
#include <mutex>

#include "common/global_flags.h"
#include "common/version_singleton.h"
#include "framework/state_dict/rec_vocab_dict.h"
#include "util/slice.h"
#include "util/tensor_helper.h"

namespace xllm {
RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version,
                                               const int32_t vocab_size,
                                               torch::ScalarType dtype,
                                               torch::Device device,
                                               bool use_gen_threadpool)
    : use_gen_threadpool_(use_gen_threadpool),
      vocab_size_(vocab_size),
      model_version_(model_version),
      device_(device),
      dtype_(dtype) {
  if (use_gen_threadpool_) {
    gen_threadpool_ = std::make_unique<ThreadPool>(GEN_MASK_THREAD_NUM);
  }

  build_mask_cache_ = false;
}

bool RecConstrainedDecoding::build_mask_cache() {
  first_token_mask_ = torch::full({vocab_size_}, PRE_MASK_FACTOR, dtype_);

  std::vector<int32_t> empty_token_ids;
  Slice<int32_t> prefix_token_ids = {empty_token_ids.data(),
                                     empty_token_ids.size()};

  const std::set<int32_t>& first_token_ids =
      VersionSingleton<RecVocabDict>::GetInstance(
          std::to_string(model_version_))
          ->get_next_tokens_by_prefix_tokens(prefix_token_ids);

  for (auto token_id : first_token_ids) {
    first_token_mask_[token_id] = 0;
  }

  first_token_mask_ = safe_to(first_token_mask_, device_, true);

  build_mask_cache_ = true;

  LOG(INFO) << "Build mask cache, first token ids size:"
            << first_token_ids.size();

  return true;
}

torch::Tensor RecConstrainedDecoding::generate_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  if (!build_mask_cache_ || 0 == generated_token_list.size()) {
    return torch::Tensor();
  }

  size_t token_size = generated_token_list[0].size();

  // Generate mask for first token
  if (0 == token_size) {
    size_t sequence_num = generated_token_list.size();
    auto mask = first_token_mask_.unsqueeze(0);
    return mask.repeat({sequence_num, 1});
  }

  // Generate mask for non-first token
  return generate_decode_mask(generated_token_list);
}

torch::Tensor RecConstrainedDecoding::generate_decode_mask(
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  size_t sequence_num = generated_token_list.size();
  torch::TensorOptions options = torch::dtype(dtype_).device(device_);
  auto mask =
      torch::full({sequence_num, vocab_size_}, PRE_MASK_FACTOR, options);

  std::mutex global_batch_mutex;
  std::vector<int64_t> global_batch_token_indices;
  std::vector<int64_t> global_batch_vocab_indices;

  int max_index_num_per_token = 8192;
  global_batch_token_indices.reserve(max_index_num_per_token * sequence_num);
  global_batch_vocab_indices.reserve(max_index_num_per_token * sequence_num);

  auto update_mask = [&](size_t start_idx, size_t end_idx) {
    std::vector<int64_t> local_token_indices;
    std::vector<int64_t> local_vocab_indices;
    local_token_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));
    local_vocab_indices.reserve(max_index_num_per_token *
                                (end_idx - start_idx));

    for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) {
      Slice<int32_t> tokens_slice(generated_token_list[token_idx]);

      const std::set<int32_t>& next_token_ids =
          VersionSingleton<RecVocabDict>::GetInstance(
              std::to_string(model_version_))
              ->get_next_tokens_by_prefix_tokens(tokens_slice);

      if (next_token_ids.size() > 0) {
        for (int32_t vocab_idx : next_token_ids) {
          local_token_indices.push_back(static_cast<int64_t>(token_idx));
          local_vocab_indices.push_back(static_cast<int64_t>(vocab_idx));
        }
      } else {
        LOG(ERROR) << "Fail to generate mask for tokens:"
                   << generated_token_list[token_idx];
      }
    }

    // Merge local results to global batch (thread-safe)
    if (!local_token_indices.empty()) {
      std::lock_guard<std::mutex> lock(global_batch_mutex);
      global_batch_token_indices.insert(global_batch_token_indices.end(),
                                        local_token_indices.begin(),
                                        local_token_indices.end());
      global_batch_vocab_indices.insert(global_batch_vocab_indices.end(),
                                        local_vocab_indices.begin(),
                                        local_vocab_indices.end());
    }
  };

  if (use_gen_threadpool_) {
    const size_t batch_size = std::max(
        1UL, (sequence_num + GEN_MASK_THREAD_NUM - 1) / GEN_MASK_THREAD_NUM);
    const size_t num_batches = (sequence_num + batch_size - 1) / batch_size;

    std::vector<std::future<void>> futures;
    std::vector<std::shared_ptr<std::promise<void>>> promises;

    promises.reserve(num_batches);
    futures.reserve(num_batches);

    for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
      auto promise = std::make_shared<std::promise<void>>();
      futures.push_back(promise->get_future());
      promises.push_back(promise);

      size_t start_idx = batch_idx * batch_size;
      size_t end_idx = std::min(start_idx + batch_size, sequence_num);

      gen_threadpool_->schedule(
          [update_mask, start_idx, end_idx, promise]() mutable {
            update_mask(start_idx, end_idx);
            promise->set_value();
          });
    }

    for (auto& future : futures) {
      future.get();
    }
  } else {
    update_mask(0, sequence_num);
  }

  if (!global_batch_token_indices.empty()) {
    auto token_indices =
        torch::tensor(global_batch_token_indices, torch::kInt64);
    auto vocab_indices =
        torch::tensor(global_batch_vocab_indices, torch::kInt64);
    token_indices = safe_to(token_indices, device_, true);
    vocab_indices = safe_to(vocab_indices, device_, true);
    mask.index_put_({token_indices, vocab_indices}, 0.0f);
  }

  return mask;
}
}  // namespace xllm
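For intuition about the scatter at the end of generate_decode_mask: the collected (sequence, vocab) index pairs zero out the allowed positions of a tensor pre-filled with PRE_MASK_FACTOR. A standalone toy sketch, with made-up values, not from this commit:

// Toy sketch of the index_put_ scatter (values are made up): start from
// a 2 x 5 mask filled with the penalty, then mark positions (0,1), (0,3)
// and (1,4) as allowed.
#include <torch/torch.h>

int main() {
  auto mask = torch::full({2, 5}, -10000.0f);
  auto rows = torch::tensor({0, 0, 1}, torch::kInt64);
  auto cols = torch::tensor({1, 3, 4}, torch::kInt64);
  mask.index_put_({rows, cols}, 0.0f);
  // mask now holds 0.0f only at [0][1], [0][3] and [1][4]; adding it to
  // logits of the same shape disables every other token.
  return 0;
}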
xllm/core/framework/sampling/rec_constrained_decoding.h

Lines changed: 58 additions & 0 deletions

/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once
#include <torch/torch.h>
#include <torch/types.h>

#include "constrained_decoding.h"
#include "util/threadpool.h"

namespace xllm {

class RecConstrainedDecoding : public ConstrainedDecoding {
 public:
  RecConstrainedDecoding(uint64_t model_version,
                         const int32_t vocab_size,
                         torch::ScalarType dtype,
                         torch::Device device,
                         bool use_gen_threadpool_ = true);
  virtual ~RecConstrainedDecoding() = default;

  bool build_mask_cache() override;

  torch::Tensor generate_mask(
      const std::vector<std::vector<int32_t>>& generated_token_list) override;

 private:
  torch::Tensor generate_decode_mask(
      const std::vector<std::vector<int32_t>>& generated_token_list);

 private:
  constexpr static float PRE_MASK_FACTOR = -10000.0f;
  constexpr static int GEN_MASK_THREAD_NUM = 16;

 private:
  bool build_mask_cache_;
  bool use_gen_threadpool_;
  int32_t vocab_size_;
  uint64_t model_version_;
  torch::Device device_;
  torch::ScalarType dtype_;
  torch::Tensor first_token_mask_;
  std::unique_ptr<ThreadPool> gen_threadpool_;
};

}  // namespace xllm
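Taken together, a caller would drive the new class roughly as follows. This is a hypothetical usage sketch, not code from this commit; it assumes a RecVocabDict for the given model version is already registered with VersionSingleton, and the vocab size, dtype, and device are placeholders.

// Hypothetical usage sketch (not from this commit). Assumes the
// RecVocabDict for model version 1 is already registered; vocab size,
// dtype and device are placeholders.
#include <cstdint>
#include <vector>

#include "rec_constrained_decoding.h"

void example() {
  xllm::RecConstrainedDecoding decoder(
      /*model_version=*/1, /*vocab_size=*/32000,
      torch::kFloat32, torch::Device(torch::kCPU));

  decoder.build_mask_cache();  // precompute the first-token whitelist

  // Two sequences in the batch, each with two tokens generated so far.
  std::vector<std::vector<int32_t>> generated = {{12, 7}, {12, 9}};
  torch::Tensor mask = decoder.generate_mask(generated);  // [2, 32000]
  // logits += mask;  // then sample as usual
}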
