Skip to content

Commit 257c867

Browse files
feat: add generative recommendation tokenizer. (jd-opensource#317)
* feat: add generative recommendation tokenizer. * feat: add xllm header, fix log style, etc. * feat: add xllm header, fix annotation style, etc. * feat: fix log style, etc.
1 parent ed437a6 commit 257c867

16 files changed

+560
-7
lines changed

xllm/core/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ cc_library(
1616
rate_limiter.h
1717
types.h
1818
device_monitor.h
19+
version_singleton.h
1920
SRCS
2021
etcd_client.cpp
2122
global_flags.cpp

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,3 +458,9 @@ DEFINE_int64(dit_cache_skip_interval_steps,
458458
DEFINE_double(dit_cache_residual_diff_threshold,
459459
0.09f,
460460
"The residual difference threshold for cache reuse.");
461+
462+
DEFINE_bool(enable_constrained_decoding,
463+
false,
464+
"Whether to enable constrained decoding, which is used to ensure "
465+
"that the output meets specific format or structural requirements "
466+
"through pre-defined rules.");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,5 @@ DECLARE_int64(dit_cache_n_derivatives);
224224
DECLARE_int64(dit_cache_skip_interval_steps);
225225

226226
DECLARE_double(dit_cache_residual_diff_threshold);
227+
228+
DECLARE_bool(enable_constrained_decoding);

xllm/core/common/types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,4 +289,7 @@ struct EplbInfo {
289289
int32_t update_layer_id = -1;
290290
};
291291

292+
inline constexpr int REC_TOKEN_SIZE = 3;
293+
294+
using RecTokenTriple = std::array<int32_t, REC_TOKEN_SIZE>;
292295
} // namespace xllm
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================*/
16+
17+
#pragma once
18+
19+
#include <list>
20+
#include <memory>
21+
#include <mutex>
22+
#include <shared_mutex>
23+
#include <string>
24+
#include <unordered_map>
25+
#include <vector>
26+
27+
namespace xllm {

// Registry of lazily-created singleton instances keyed by a version string.
// Thread-safe: lookups take a shared (reader) lock; creation, eviction and
// destruction take an exclusive (writer) lock on the same shared_mutex.
template <typename T>
class VersionSingleton {
 public:
  // Returns the instance registered under `version`, creating it from
  // `args` on first use. When `delete_old_versions` is true, only the most
  // recently created `reserved_version_size` versions are retained and older
  // instances are destroyed. Pointers to evicted versions dangle, so callers
  // must not cache the result across creations of newer versions.
  template <typename... Args>
  static T* GetInstance(const std::string& version,
                        bool delete_old_versions = true,
                        int reserved_version_size =
                            2,  // default retention of the last two versions
                        Args&&... args) {
    T* instance = nullptr;

    // Fast path: shared lock for the common case of an existing version.
    {
      std::shared_lock<std::shared_mutex> lock(instance_map_mutex_);
      auto it = instance_map_.find(version);
      if (it != instance_map_.end()) {
        instance = it->second.get();
      }
    }

    if (instance == nullptr) {
      std::unique_lock<std::shared_mutex> lock(instance_map_mutex_);

      // Re-check under the exclusive lock: another thread may have created
      // this version between the two lock scopes.
      auto it = instance_map_.find(version);
      if (it == instance_map_.end()) {
        instance = new T(std::forward<Args>(args)...);
        instance_map_[version] = std::unique_ptr<T>(instance);
        instance_version_list_.push_front(version);
        if (delete_old_versions && reserved_version_size >= 0) {
          const size_t reserved = static_cast<size_t>(reserved_version_size);
          if (instance_version_list_.size() > reserved) {
            // Drop every version beyond the newest `reserved` ones.
            auto evict_it = instance_version_list_.begin();
            std::advance(evict_it, reserved);
            for (; evict_it != instance_version_list_.end(); ++evict_it) {
              instance_map_.erase(*evict_it);
            }
            instance_version_list_.resize(reserved);
          }
        }
      } else {
        instance = it->second.get();
      }
    }

    return instance;
  }

  // Returns the version keys of all live instances (unordered).
  static std::vector<std::string> GetVersions() {
    // Read-only access: a shared lock is sufficient (and, unlike
    // std::lock_guard<std::mutex>, actually compiles against shared_mutex).
    std::shared_lock<std::shared_mutex> lock(instance_map_mutex_);
    std::vector<std::string> versions;
    versions.reserve(instance_map_.size());
    for (const auto& pair : instance_map_) {
      versions.push_back(pair.first);
    }
    return versions;
  }

  // Destroys every live instance and clears the version history.
  static void DestroyAllInstances() {
    std::unique_lock<std::shared_mutex> lock(instance_map_mutex_);
    instance_map_.clear();
    instance_version_list_.clear();
  }

  VersionSingleton(const VersionSingleton&) = delete;
  VersionSingleton& operator=(const VersionSingleton&) = delete;

 private:
  VersionSingleton() = default;
  ~VersionSingleton() = default;

  // version -> owned instance.
  static std::unordered_map<std::string, std::unique_ptr<T>> instance_map_;
  // Versions in creation order, newest first; drives eviction.
  static std::list<std::string> instance_version_list_;
  // Guards both containers above.
  static std::shared_mutex instance_map_mutex_;
};

template <typename T>
std::unordered_map<std::string, std::unique_ptr<T>>
    VersionSingleton<T>::instance_map_;
template <typename T>
std::list<std::string> VersionSingleton<T>::instance_version_list_;
template <typename T>
std::shared_mutex VersionSingleton<T>::instance_map_mutex_;

}  // namespace xllm

xllm/core/framework/hf_model_loader.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ limitations under the License.
2525
#include <filesystem>
2626
#include <vector>
2727

28+
#include "core/common/version_singleton.h"
29+
#include "core/framework/state_dict/rec_vocab_dict.h"
2830
#include "core/framework/tokenizer/fast_tokenizer.h"
31+
#include "core/framework/tokenizer/rec_tokenizer.h"
2932
#include "core/framework/tokenizer/sentencepiece_tokenizer.h"
3033
#include "core/framework/tokenizer/tiktoken_tokenizer.h"
3134
#include "core/framework/tokenizer/tokenizer_factory.h"
@@ -51,7 +54,13 @@ HFModelLoader::HFModelLoader(const std::string& model_weights_path)
5154
<< "Failed to find model weights files in " << model_weights_path;
5255
// sort the model weights files by name
5356
std::sort(model_weights_files_.begin(), model_weights_files_.end());
57+
5458
threadpool_ = std::make_unique<ThreadPool>(32);
59+
60+
if (FLAGS_backend == "rec") {
61+
CHECK(load_rec_vocab(model_weights_path))
62+
<< "Failed to load rec content from " << model_weights_path;
63+
}
5564
}
5665

5766
std::unique_ptr<Tokenizer> HFModelLoader::tokenizer() const {
@@ -80,6 +89,28 @@ std::vector<std::unique_ptr<StateDict>>& HFModelLoader::get_state_dicts() {
8089
return state_dicts_;
8190
}
8291

92+
bool HFModelLoader::load_rec_vocab(const std::string& model_weights_path) {
93+
if (!tokenizer_args_.vocab_file().empty()) {
94+
std::filesystem::path path = model_weights_path;
95+
std::string model_version = path.filename();
96+
std::string vocab_full_path =
97+
path.append(tokenizer_args_.vocab_file()).string();
98+
99+
LOG(INFO) << "Model_version: " << model_version
100+
<< ", vocab_full_path: " << vocab_full_path;
101+
102+
CHECK(nullptr != VersionSingleton<RecVocabDict>::GetInstance(model_version))
103+
<< "Failed to get vocab dict instance";
104+
CHECK(VersionSingleton<RecVocabDict>::GetInstance(model_version)
105+
->initialize(vocab_full_path))
106+
<< "Failed to initialize vocab dict from " << vocab_full_path;
107+
} else {
108+
LOG(ERROR) << "Vocab file is not set";
109+
}
110+
111+
return true;
112+
}
113+
83114
bool HFModelLoader::load_args(const std::string& model_weights_path) {
84115
if (!load_model_args(model_weights_path)) {
85116
LOG(ERROR) << "Failed to load model args from " << model_weights_path;

xllm/core/framework/hf_model_loader.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class HFModelLoader : public ModelLoader {
3636

3737
private:
3838
bool load_args(const std::string& model_weights_path);
39+
bool load_rec_vocab(const std::string& model_weights_path);
3940
bool load_model_args(const std::string& model_weights_path);
4041
bool load_quant_args(const std::string& model_weights_path);
4142
bool load_tokenizer_args(const std::string& model_weights_path);

xllm/core/framework/state_dict/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ cc_library(
1111
HDRS
1212
state_dict.h
1313
utils.h
14+
rec_vocab_dict.h
1415
SRCS
1516
state_dict.cpp
1617
utils.cpp
18+
rec_vocab_dict.cpp
1719
DEPS
1820
rust_safetensors
1921
torch
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#include "rec_vocab_dict.h"
2+
3+
#include <algorithm>
4+
#include <array>
5+
#include <filesystem>
6+
#include <fstream>
7+
8+
#include "common/global_flags.h"
9+
#include "util/timer.h"
10+
11+
namespace xllm {
12+
13+
bool RecVocabDict::initialize(const std::string& vocab_file) {
14+
if (initialized_) {
15+
return true;
16+
}
17+
18+
Timer timer;
19+
20+
if (vocab_file.empty()) {
21+
LOG(ERROR) << "Content data file is empty, file: " << vocab_file;
22+
return false;
23+
}
24+
if (!std::filesystem::exists(vocab_file)) {
25+
LOG(ERROR) << "Fail to find content data file: " << vocab_file;
26+
return false;
27+
}
28+
std::ifstream ifs(vocab_file.data(), std::ios::binary | std::ios::ate);
29+
if (!ifs.is_open()) {
30+
LOG(ERROR) << "Fail to load content data file: " << vocab_file;
31+
return false;
32+
}
33+
34+
const size_t file_size = ifs.tellg();
35+
ifs.seekg(0, std::ios::beg);
36+
37+
// Each line of content : 1 * int64_t(item id) + REC_TOKEN_SIZE *
38+
// int32_t(token id);
39+
const size_t itemid_size = sizeof(int64_t);
40+
const size_t tokens_size = REC_TOKEN_SIZE * sizeof(int32_t);
41+
const size_t line_size = tokens_size + itemid_size;
42+
const size_t estimated_lines = (file_size + line_size - 1) / line_size;
43+
44+
// 2 and 4 are only empirical values
45+
item_to_tokens_map_.reserve(estimated_lines);
46+
tokens_to_items_map_.reserve(estimated_lines / 2);
47+
prefix_tokens_to_next_tokens_map_.reserve(estimated_lines / 4);
48+
49+
int64_t item_id = 0;
50+
RecTokenTriple tokens;
51+
52+
while (ifs.read(reinterpret_cast<char*>(&item_id), itemid_size) &&
53+
ifs.read(reinterpret_cast<char*>(tokens.data()), tokens_size)) {
54+
if (FLAGS_enable_constrained_decoding) {
55+
for (int i = 0; i < tokens.size(); i++) {
56+
std::vector<int32_t> prefix_tokens;
57+
58+
for (int j = 0; j < i; j++) {
59+
prefix_tokens.emplace_back(tokens[j]);
60+
}
61+
62+
prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]);
63+
}
64+
}
65+
66+
item_to_tokens_map_[item_id] = tokens;
67+
68+
tokens_to_items_map_[tokens].emplace_back(item_id);
69+
}
70+
71+
if (ifs.gcount() != 0 && ifs.gcount() != line_size) {
72+
LOG(ERROR) << "Possibly containing incomplete lines : " << vocab_file;
73+
item_to_tokens_map_.clear();
74+
tokens_to_items_map_.clear();
75+
prefix_tokens_to_next_tokens_map_.clear();
76+
return false;
77+
}
78+
79+
initialized_ = true;
80+
LOG(INFO) << "Total line size:" << estimated_lines
81+
<< ",parse tokens to item id map size: "
82+
<< tokens_to_items_map_.size()
83+
<< ", parse item to tokens map size:" << item_to_tokens_map_.size()
84+
<< ", parse prefix tokens to next tokens map size:"
85+
<< prefix_tokens_to_next_tokens_map_.size()
86+
<< ", cost: " << timer.elapsed_seconds() << " seconds";
87+
88+
return true;
89+
}
90+
91+
bool RecVocabDict::get_items_by_tokens(const RecTokenTriple& rec_token_triple,
92+
std::vector<int64_t>* item_ids) const {
93+
CHECK_EQ(initialized_, true);
94+
CHECK_NE(item_ids, nullptr);
95+
96+
auto iter = tokens_to_items_map_.find(rec_token_triple);
97+
if (iter == tokens_to_items_map_.end()) {
98+
return false;
99+
}
100+
101+
std::copy(
102+
iter->second.begin(), iter->second.end(), std::back_inserter(*item_ids));
103+
104+
return true;
105+
}
106+
107+
bool RecVocabDict::get_tokens_by_item(int64_t item_id,
108+
std::vector<int32_t>* token_ids) const {
109+
CHECK_EQ(initialized_, true);
110+
CHECK_NE(token_ids, nullptr);
111+
112+
auto iter = item_to_tokens_map_.find(item_id);
113+
if (iter == item_to_tokens_map_.end()) {
114+
return false;
115+
}
116+
117+
std::copy(
118+
iter->second.begin(), iter->second.end(), std::back_inserter(*token_ids));
119+
120+
return true;
121+
}
122+
123+
const std::set<int32_t>& RecVocabDict::get_next_tokens_by_prefix_tokens(
124+
const Slice<int32_t>& prefix_token_ids) const {
125+
CHECK_EQ(initialized_, true);
126+
CHECK_LT(prefix_token_ids.size(), REC_TOKEN_SIZE);
127+
128+
std::vector<int32_t> prefix_tokens_ids_vec = prefix_token_ids;
129+
auto iter = prefix_tokens_to_next_tokens_map_.find(prefix_tokens_ids_vec);
130+
if (iter == prefix_tokens_to_next_tokens_map_.end()) {
131+
static std::set<int32_t> empty_set;
132+
return empty_set;
133+
}
134+
135+
return iter->second;
136+
}
137+
138+
} // namespace xllm

0 commit comments

Comments
 (0)