 #include <oniguruma.h>
 #include <utf8proc/utf8proc.h>
 #include <iostream>
+#include <mutex>
 #include "ujson.hpp"
 #include "jinja.hpp"
 
@@ -109,21 +110,23 @@ static std::string get_token_content(const json& j) {
     return "";
 }
 
-static std::unordered_map<unsigned char, std::string> create_bytes_char_map() {
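+// Byte-level BPE table in the style of GPT-2's bytes_to_unicode:
+// index = raw byte value, entry = the printable UTF-8 string that stands in for it.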
+static std::vector<std::string> create_bytes_char_map() {
     auto u2u = [](int cp) -> std::string {
         std::string out;
         if (cp <= 0x7F) out += (char)cp;
         else if (cp <= 0x7FF) { out += (char)(0xC0 | (cp >> 6)); out += (char)(0x80 | (cp & 0x3F)); }
         else if (cp <= 0xFFFF) { out += (char)(0xE0 | (cp >> 12)); out += (char)(0x80 | ((cp >> 6) & 0x3F)); out += (char)(0x80 | (cp & 0x3F)); }
         return out;
     };
-    std::unordered_map<unsigned char, std::string> bs;
-    for (int b = 33; b <= 126; ++b) bs[(unsigned char)b] = u2u(b);
-    for (int b = 161; b <= 172; ++b) bs[(unsigned char)b] = u2u(b);
-    for (int b = 174; b <= 255; ++b) bs[(unsigned char)b] = u2u(b);
+    std::vector<std::string> bs(256);
+    std::unordered_map<unsigned char, std::string> temp_bs;
+    for (int b = 33; b <= 126; ++b) temp_bs[(unsigned char)b] = u2u(b);
+    for (int b = 161; b <= 172; ++b) temp_bs[(unsigned char)b] = u2u(b);
+    for (int b = 174; b <= 255; ++b) temp_bs[(unsigned char)b] = u2u(b);
     int n = 0;
     for (int b = 0; b < 256; ++b) {
-        if (bs.find((unsigned char)b) == bs.end()) bs[(unsigned char)b] = u2u(256 + n++);
+        if (temp_bs.find((unsigned char)b) == temp_bs.end()) temp_bs[(unsigned char)b] = u2u(256 + n++);
+        bs[b] = temp_bs[(unsigned char)b];
     }
     return bs;
 }
@@ -499,12 +502,23 @@ class BertPreTokenizer : public PreTokenizer {
 
 // Moved create_bytes_char_map up
 
+struct PairHash {
+    inline size_t operator()(const std::pair<int, int>& v) const {
+        // Boost hash_combine-style mixing (simplified)
+        size_t seed = 0;
+        seed ^= std::hash<int>{}(v.first) + 0x9e3779b9 + (seed<<6) + (seed>>2);
+        seed ^= std::hash<int>{}(v.second) + 0x9e3779b9 + (seed<<6) + (seed>>2);
+        return seed;
+    }
+};
+
 class BPEModel : public Model {
 public:
     bool use_byte_level_;
     std::unordered_map<std::string, int> vocab_;
     std::unordered_map<int, std::string> id_to_token_;
-    std::map<std::pair<int, int>, int> merges_;
+    std::unordered_map<std::pair<int, int>, int, PairHash> merges_;
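+    // Guards cache_: tokenize() is const but memoizes results, so concurrent calls need the lock.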
+    mutable std::mutex cache_mutex_;
     mutable std::unordered_map<std::string, std::vector<int>> cache_;
 
     BPEModel(const std::map<std::string, int>& vocab,
@@ -529,8 +543,12 @@ class BPEModel : public Model {
 
     std::vector<int> tokenize(const std::string& text) const override {
         if (text.empty()) return {};
-        auto cit = cache_.find(text);
-        if (cit != cache_.end()) return cit->second;
+        {
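+            // Fast path: return the memoized tokenization if this text was seen before.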
+            std::lock_guard<std::mutex> lock(cache_mutex_);
+            auto cit = cache_.find(text);
+            if (cit != cache_.end()) return cit->second;
+        }
+
         std::vector<int> out;
         if (use_byte_level_) {
             static auto byte_map = create_bytes_char_map();
@@ -572,7 +590,10 @@ class BPEModel : public Model {
             int nid = token_to_id(m); if (nid == -1) break;
             out[best] = nid; out.erase(out.begin() + best + 1);
         }
-        cache_[text] = out;
+        {
+            std::lock_guard<std::mutex> lock(cache_mutex_);
+            cache_[text] = out;
+        }
         return out;
     }
 
@@ -874,7 +895,10 @@ class ByteLevelDecoder : public Decoder {
     void decode(std::vector<std::string>& tokens) const override {
         static auto bm = []() {
             std::unordered_map<std::string, unsigned char> m;
-            for (const auto& p : create_bytes_char_map()) m[p.second] = p.first;
+            auto byte_vec = create_bytes_char_map();
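+            // Invert the byte-level table: printable UTF-8 stand-in -> original byte value.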
+            for (int i = 0; i < 256; ++i) {
+                if (!byte_vec[i].empty()) m[byte_vec[i]] = (unsigned char)i;
+            }
             return m;
         }();
         for (auto& t : tokens) {