1+ #include " rec_vocab_dict.h"
2+
3+ #include < algorithm>
4+ #include < array>
5+ #include < filesystem>
6+ #include < fstream>
7+
8+ #include " common/global_flags.h"
9+ #include " util/timer.h"
10+
11+ namespace xllm {
12+
13+ bool RecVocabDict::initialize (const std::string& vocab_file) {
14+ if (initialized_) {
15+ return true ;
16+ }
17+
18+ Timer timer;
19+
20+ if (vocab_file.empty ()) {
21+ LOG (ERROR) << " Content data file is empty, file: " << vocab_file;
22+ return false ;
23+ }
24+ if (!std::filesystem::exists (vocab_file)) {
25+ LOG (ERROR) << " Fail to find content data file: " << vocab_file;
26+ return false ;
27+ }
28+ std::ifstream ifs (vocab_file.data (), std::ios::binary | std::ios::ate);
29+ if (!ifs.is_open ()) {
30+ LOG (ERROR) << " Fail to load content data file: " << vocab_file;
31+ return false ;
32+ }
33+
34+ const size_t file_size = ifs.tellg ();
35+ ifs.seekg (0 , std::ios::beg);
36+
37+ // Each line of content : 1 * int64_t(item id) + REC_TOKEN_SIZE *
38+ // int32_t(token id);
39+ const size_t itemid_size = sizeof (int64_t );
40+ const size_t tokens_size = REC_TOKEN_SIZE * sizeof (int32_t );
41+ const size_t line_size = tokens_size + itemid_size;
42+ const size_t estimated_lines = (file_size + line_size - 1 ) / line_size;
43+
44+ // 2 and 4 are only empirical values
45+ item_to_tokens_map_.reserve (estimated_lines);
46+ tokens_to_items_map_.reserve (estimated_lines / 2 );
47+ prefix_tokens_to_next_tokens_map_.reserve (estimated_lines / 4 );
48+
49+ int64_t item_id = 0 ;
50+ RecTokenTriple tokens;
51+
52+ while (ifs.read (reinterpret_cast <char *>(&item_id), itemid_size) &&
53+ ifs.read (reinterpret_cast <char *>(tokens.data ()), tokens_size)) {
54+ if (FLAGS_enable_constrained_decoding) {
55+ for (int i = 0 ; i < tokens.size (); i++) {
56+ std::vector<int32_t > prefix_tokens;
57+
58+ for (int j = 0 ; j < i; j++) {
59+ prefix_tokens.emplace_back (tokens[j]);
60+ }
61+
62+ prefix_tokens_to_next_tokens_map_[prefix_tokens].insert (tokens[i]);
63+ }
64+ }
65+
66+ item_to_tokens_map_[item_id] = tokens;
67+
68+ tokens_to_items_map_[tokens].emplace_back (item_id);
69+ }
70+
71+ if (ifs.gcount () != 0 && ifs.gcount () != line_size) {
72+ LOG (ERROR) << " Possibly containing incomplete lines : " << vocab_file;
73+ item_to_tokens_map_.clear ();
74+ tokens_to_items_map_.clear ();
75+ prefix_tokens_to_next_tokens_map_.clear ();
76+ return false ;
77+ }
78+
79+ initialized_ = true ;
80+ LOG (INFO) << " Total line size:" << estimated_lines
81+ << " ,parse tokens to item id map size: "
82+ << tokens_to_items_map_.size ()
83+ << " , parse item to tokens map size:" << item_to_tokens_map_.size ()
84+ << " , parse prefix tokens to next tokens map size:"
85+ << prefix_tokens_to_next_tokens_map_.size ()
86+ << " , cost: " << timer.elapsed_seconds () << " seconds" ;
87+
88+ return true ;
89+ }
90+
91+ bool RecVocabDict::get_items_by_tokens (const RecTokenTriple& rec_token_triple,
92+ std::vector<int64_t >* item_ids) const {
93+ CHECK_EQ (initialized_, true );
94+ CHECK_NE (item_ids, nullptr );
95+
96+ auto iter = tokens_to_items_map_.find (rec_token_triple);
97+ if (iter == tokens_to_items_map_.end ()) {
98+ return false ;
99+ }
100+
101+ std::copy (
102+ iter->second .begin (), iter->second .end (), std::back_inserter (*item_ids));
103+
104+ return true ;
105+ }
106+
107+ bool RecVocabDict::get_tokens_by_item (int64_t item_id,
108+ std::vector<int32_t >* token_ids) const {
109+ CHECK_EQ (initialized_, true );
110+ CHECK_NE (token_ids, nullptr );
111+
112+ auto iter = item_to_tokens_map_.find (item_id);
113+ if (iter == item_to_tokens_map_.end ()) {
114+ return false ;
115+ }
116+
117+ std::copy (
118+ iter->second .begin (), iter->second .end (), std::back_inserter (*token_ids));
119+
120+ return true ;
121+ }
122+
123+ const std::set<int32_t >& RecVocabDict::get_next_tokens_by_prefix_tokens (
124+ const Slice<int32_t >& prefix_token_ids) const {
125+ CHECK_EQ (initialized_, true );
126+ CHECK_LT (prefix_token_ids.size (), REC_TOKEN_SIZE);
127+
128+ std::vector<int32_t > prefix_tokens_ids_vec = prefix_token_ids;
129+ auto iter = prefix_tokens_to_next_tokens_map_.find (prefix_tokens_ids_vec);
130+ if (iter == prefix_tokens_to_next_tokens_map_.end ()) {
131+ static std::set<int32_t > empty_set;
132+ return empty_set;
133+ }
134+
135+ return iter->second ;
136+ }
137+
138+ } // namespace xllm
0 commit comments