Skip to content

Commit a9e88dd

Browse files
committed
perf(core): optimize BPE lookup/byte_map and ensure thread safety
1 parent 37ecc22 commit a9e88dd

File tree

1 file changed

+35
-11
lines changed

1 file changed

+35
-11
lines changed

src/tokenizer.cpp

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <oniguruma.h>
1717
#include <utf8proc/utf8proc.h>
1818
#include <iostream>
19+
#include <mutex>
1920
#include "ujson.hpp"
2021
#include "jinja.hpp"
2122

@@ -109,21 +110,23 @@ static std::string get_token_content(const json& j) {
109110
return "";
110111
}
111112

112-
static std::unordered_map<unsigned char, std::string> create_bytes_char_map() {
113+
static std::vector<std::string> create_bytes_char_map() {
113114
auto u2u = [](int cp) -> std::string {
114115
std::string out;
115116
if (cp <= 0x7F) out += (char)cp;
116117
else if (cp <= 0x7FF) { out += (char)(0xC0 | (cp >> 6)); out += (char)(0x80 | (cp & 0x3F)); }
117118
else if (cp <= 0xFFFF) { out += (char)(0xE0 | (cp >> 12)); out += (char)(0x80 | ((cp >> 6) & 0x3F)); out += (char)(0x80 | (cp & 0x3F)); }
118119
return out;
119120
};
120-
std::unordered_map<unsigned char, std::string> bs;
121-
for (int b = 33; b <= 126; ++b) bs[(unsigned char)b] = u2u(b);
122-
for (int b = 161; b <= 172; ++b) bs[(unsigned char)b] = u2u(b);
123-
for (int b = 174; b <= 255; ++b) bs[(unsigned char)b] = u2u(b);
121+
std::vector<std::string> bs(256);
122+
std::unordered_map<unsigned char, std::string> temp_bs;
123+
for (int b = 33; b <= 126; ++b) temp_bs[(unsigned char)b] = u2u(b);
124+
for (int b = 161; b <= 172; ++b) temp_bs[(unsigned char)b] = u2u(b);
125+
for (int b = 174; b <= 255; ++b) temp_bs[(unsigned char)b] = u2u(b);
124126
int n = 0;
125127
for (int b = 0; b < 256; ++b) {
126-
if (bs.find((unsigned char)b) == bs.end()) bs[(unsigned char)b] = u2u(256 + n++);
128+
if (temp_bs.find((unsigned char)b) == temp_bs.end()) temp_bs[(unsigned char)b] = u2u(256 + n++);
129+
bs[b] = temp_bs[(unsigned char)b];
127130
}
128131
return bs;
129132
}
@@ -499,12 +502,23 @@ class BertPreTokenizer : public PreTokenizer {
499502

500503
// Moved create_bytes_char_map up
501504

505+
struct PairHash {
506+
inline size_t operator()(const std::pair<int, int>& v) const {
507+
// IntPairHash from boost (simplified)
508+
size_t seed = 0;
509+
seed ^= std::hash<int>{}(v.first) + 0x9e3779b9 + (seed<<6) + (seed>>2);
510+
seed ^= std::hash<int>{}(v.second) + 0x9e3779b9 + (seed<<6) + (seed>>2);
511+
return seed;
512+
}
513+
};
514+
502515
class BPEModel : public Model {
503516
public:
504517
bool use_byte_level_;
505518
std::unordered_map<std::string, int> vocab_;
506519
std::unordered_map<int, std::string> id_to_token_;
507-
std::map<std::pair<int, int>, int> merges_;
520+
std::unordered_map<std::pair<int, int>, int, PairHash> merges_;
521+
mutable std::mutex cache_mutex_;
508522
mutable std::unordered_map<std::string, std::vector<int>> cache_;
509523

510524
BPEModel(const std::map<std::string, int>& vocab,
@@ -529,8 +543,12 @@ class BPEModel : public Model {
529543

530544
std::vector<int> tokenize(const std::string& text) const override {
531545
if (text.empty()) return {};
532-
auto cit = cache_.find(text);
533-
if (cit != cache_.end()) return cit->second;
546+
{
547+
std::lock_guard<std::mutex> lock(cache_mutex_);
548+
auto cit = cache_.find(text);
549+
if (cit != cache_.end()) return cit->second;
550+
}
551+
534552
std::vector<int> out;
535553
if (use_byte_level_) {
536554
static auto byte_map = create_bytes_char_map();
@@ -572,7 +590,10 @@ class BPEModel : public Model {
572590
int nid = token_to_id(m); if (nid == -1) break;
573591
out[best] = nid; out.erase(out.begin() + best + 1);
574592
}
575-
cache_[text] = out;
593+
{
594+
std::lock_guard<std::mutex> lock(cache_mutex_);
595+
cache_[text] = out;
596+
}
576597
return out;
577598
}
578599

@@ -874,7 +895,10 @@ class ByteLevelDecoder : public Decoder {
874895
void decode(std::vector<std::string>& tokens) const override {
875896
static auto bm = []() {
876897
std::unordered_map<std::string, unsigned char> m;
877-
for (const auto& p : create_bytes_char_map()) m[p.second] = p.first;
898+
auto byte_vec = create_bytes_char_map();
899+
for (int i = 0; i < 256; ++i) {
900+
if (!byte_vec[i].empty()) m[byte_vec[i]] = (unsigned char)i;
901+
}
878902
return m;
879903
}();
880904
for (auto& t : tokens) {

0 commit comments

Comments
 (0)