@@ -5,7 +5,7 @@
 from ..util import synchronized
 from ..util.file import maybe_read_json
 from ..model.config import Config
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import TYPE_CHECKING
 from ..util import profile_opt
 if TYPE_CHECKING:
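
The import change is the crux of the commit: each `@lru_cache`-decorated instance method below becomes a `@cached_property`. The usual reason for this swap, and the observable difference, is lifetime: `lru_cache` on a method stores `self` as part of the key in a cache hung off the function object, so every tokenizer instance ever used stays alive for the life of the process, while `cached_property` writes the computed value into the instance's own `__dict__` and is freed together with the instance. A minimal sketch of that difference, with illustrative class names that are not from this repo:

```python
import gc
import weakref
from functools import lru_cache, cached_property

class Leaky:
    @lru_cache                # cache lives on the function, keyed by self
    def vocab_size(self):
        return 32000

class Clean:
    @cached_property          # cache lives in the instance __dict__
    def vocab_size(self):
        return 32000

leaky, clean = Leaky(), Clean()
leaky.vocab_size()            # method call, with parentheses
clean.vocab_size              # property access, no parentheses
r1, r2 = weakref.ref(leaky), weakref.ref(clean)
del leaky, clean
gc.collect()
print(r1() is None)           # False: the lru_cache entry still holds the instance
print(r2() is None)           # True: the cached value died with the instance
```
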
@@ -162,7 +162,7 @@ def get_default_token_id(config_key: str, current: int | None, default: int | None):
         self.actual_vocab_size = 1 + max(
             list(self.extended_id_to_piece.keys()) + \
             list(self.unspecial_id_to_piece.keys()) + \
-            [self.raw_vocab_size() - 1]
+            [self.raw_vocab_size - 1]
         )
 
         # Useful token IDs
@@ -185,7 +185,7 @@ def get_default_token_id(config_key: str, current: int | None, default: int | None):
         self.get_id_to_piece_list(False)
         self.get_piece_to_id_dict()
 
-    @lru_cache
+    @cached_property
     def raw_vocab_size(self):
        """
        Cache this function because it's suspiciously slow in HF Tokenizers
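
The docstring explains why the value is cached at all; the decorator change alters when and how. The body now runs once, on first attribute access, and later lookups come straight from the instance `__dict__`. Every call site also has to drop its parentheses, which is exactly what the remaining hunks do. A toy model of the new access pattern (illustrative class, not the real tokenizer):

```python
from functools import cached_property

class Tok:
    @cached_property
    def raw_vocab_size(self):
        print("computed once")   # stand-in for the slow HF Tokenizers call
        return 32000

t = Tok()
t.raw_vocab_size                 # prints "computed once", returns 32000
t.raw_vocab_size                 # served from t.__dict__, body not re-run
```

A leftover `t.raw_vocab_size()` call site would now raise `TypeError: 'int' object is not callable`, which is why the rest of the diff is mostly mechanical call-site updates.
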
@@ -383,14 +383,14 @@ def decode_(self, seq, decode_special_tokens):
 
         if not decode_special_tokens:
 
-            max_token = self.raw_vocab_size()
+            max_token = self.raw_vocab_size
             seq = [t for t in seq if (t != self.pad_token_id and t < max_token and t != self.eos_token_id)]
             if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
             return self.decode_unspecial(seq)
 
         else:
 
-            max_token = self.raw_vocab_size()
+            max_token = self.raw_vocab_size
             seq = [t for t in seq if (t != self.pad_token_id and t < max_token)]
             text = ""
             start = 0
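
Both branches of `decode_` needed only the parentheses removed; the surrounding filter is untouched. For orientation, this is what that filter does, run on made-up token ids:

```python
# Hypothetical ids standing in for the tokenizer's real ones.
pad_token_id, eos_token_id, max_token = 0, 2, 32000
seq = [15, 7, 0, 40000, 9, 2, 11]
# Non-special path: drop pads, out-of-range ids, and the EOS token.
seq = [t for t in seq if t != pad_token_id and t < max_token and t != eos_token_id]
print(seq)  # [15, 7, 9, 11]
```
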
@@ -486,11 +486,10 @@ def num_tokens(self, text):
 
     # Get ordinals of single-byte tokens
 
-    @synchronized
-    @lru_cache
-    def get_id_to_ord_list(self):
+    @cached_property
+    def _get_id_to_ord_list(self):
 
-        self.id_to_ord = list(range(self.raw_vocab_size()))
+        self.id_to_ord = list(range(self.raw_vocab_size))
 
         def clean_special_chars(p):
             p = p.replace(self.space_char_, " ")
@@ -508,7 +507,7 @@ def piece_to_ord(p):
             if o <= 255: return o
             return -1
 
-        i = self.raw_vocab_size()
+        i = self.raw_vocab_size
         while True:
             if i in self.extended_id_to_piece:
                 self.id_to_ord.append(piece_to_ord(self.extended_id_to_piece[i]))
@@ -520,10 +519,15 @@ def piece_to_ord(p):
                 break
             i += 1
 
+        return self.id_to_ord
+    @synchronized
+    def get_id_to_ord_list(self):
+        return self._get_id_to_ord_list
+
     # Copy vocabulary from model
 
-    @lru_cache
-    def get_fixed_vocab(self):
+    @cached_property
+    def _get_fixed_vocab(self):
         test_enc = self.tokenizer.encode(" t", add_special_tokens = False)
         test_count = len(test_enc.ids)
         assert test_count > 0, "Tokenizer error, test string encodes to zero tokens"
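
The old `get_id_to_ord_list` stacked `@synchronized` on `@lru_cache`; the new shape keeps the locking by splitting roles. The private `@cached_property` holds the computed list, and the public `@synchronized` method is the only entry point, so the first (computing) access happens under the lock. The wrapper is not redundant: `functools.cached_property` had an internal lock only through Python 3.11, and 3.12 removed it. Here is the pattern in isolation; the `synchronized` decorator below is a guess at what `..util` provides, not the repo's actual implementation:

```python
import threading
from functools import cached_property, wraps

_lock = threading.RLock()

def synchronized(func):
    """Assumed stand-in: serialize calls through one process-wide lock."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        with _lock:
            return func(*args, **kwargs)
    return wrapper

class Tok:
    @cached_property
    def _id_to_ord(self):            # computed once per instance
        return list(range(32000))    # stand-in for the slow scan above

    @synchronized
    def get_id_to_ord_list(self):    # lock guards the first, computing access
        return self._id_to_ord
```
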
@@ -532,35 +536,37 @@ def get_fixed_vocab(self):
 
         if test_count == 1 and len(test_piece) == len(" t"):
             vocab = self.tokenizer.decode_batch(
-                [[i] for i in range(self.raw_vocab_size())],
+                [[i] for i in range(self.raw_vocab_size)],
                 skip_special_tokens = False
             )
         else:
             prefix_id = self.tokenizer.encode(" ", add_special_tokens = False).ids[0]
             prefix_piece = self.tokenizer.decode([prefix_id])
             prefix_len = len(prefix_piece)
             vocab = self.tokenizer.decode_batch(
-                [[prefix_id, i] for i in range(self.raw_vocab_size())]
+                [[prefix_id, i] for i in range(self.raw_vocab_size)]
             )
             vocab = [v[prefix_len:] for v in vocab]
 
         return vocab
+    def get_fixed_vocab(self):
+        return self._get_fixed_vocab
 
-    @synchronized
-    @lru_cache
-    def get_id_to_piece_list(self, include_special_tokens = False):
+    @cached_property
+    def _get_id_to_piece_list_spc(self, include_special_tokens = False):
 
-        if include_special_tokens:
-            id_to_piece_extended = self.get_id_to_piece_list().copy()
-            for k, v in self.extended_id_to_piece.items():
-                id_to_piece_extended[k] = v
+        id_to_piece_extended = self.get_id_to_piece_list().copy()
+        for k, v in self.extended_id_to_piece.items():
+            id_to_piece_extended[k] = v
 
-            self.id_to_piece_with_special = id_to_piece_extended
-            return self.id_to_piece_with_special
+        self.id_to_piece_with_special = id_to_piece_extended
+        return self.id_to_piece_with_special
+    @cached_property
+    def _get_id_to_piece_list_nonspc(self, include_special_tokens = False):
 
         self.id_to_piece = self.get_fixed_vocab()
 
-        i = self.raw_vocab_size()
+        i = self.raw_vocab_size
         while True:
             if i in self.extended_id_to_piece:
                 self.id_to_piece.append(self.extended_id_to_piece[i])
@@ -573,13 +579,21 @@ def get_id_to_piece_list(self, include_special_tokens = False):
             i += 1
 
         return self.id_to_piece
-
     @synchronized
-    @lru_cache
-    def get_piece_to_id_dict(self):
+    def get_id_to_piece_list(self, include_special_tokens = False):
+        if include_special_tokens:
+            return self._get_id_to_piece_list_spc
+        else:
+            return self._get_id_to_piece_list_nonspc
+
+    @cached_property
+    def _get_piece_to_id_dict(self):
         all_pieces = self.get_id_to_piece_list()
         self.piece_to_id = {piece: idx for idx, piece in enumerate(all_pieces)}
         return self.piece_to_id
+    @synchronized
+    def get_piece_to_id_dict(self):
+        return self._get_piece_to_id_dict
 
     @staticmethod
     def from_config(config: Config):
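
`get_id_to_piece_list` is the one cached method that takes an argument, which `cached_property` cannot express, so each of the two argument values gets its own private property and a thin `@synchronized` dispatcher keeps the public signature. Note that the `include_special_tokens = False` parameter left on the two private properties is inert: `cached_property` invokes the getter with `self` alone, so the default always applies. The shape of the split, reduced to a standalone sketch with placeholder data:

```python
from functools import cached_property

class Tok:
    @cached_property
    def _pieces_special(self):
        return ["<s>", "</s>", "hello"]   # stand-in for the extended list

    @cached_property
    def _pieces_plain(self):
        return ["hello"]                  # stand-in for the base vocab

    def get_id_to_piece_list(self, include_special_tokens = False):
        # One cached value per argument value, chosen at call time.
        if include_special_tokens:
            return self._pieces_special
        return self._pieces_plain
```
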
@@ -603,12 +617,14 @@ def get_tokens_with_prefix_id(self, prefix_id: int):
         prefix = id_to_piece[prefix_id]
         return self.get_tokens_with_prefix_string(prefix)
 
-    @lru_cache
-    def get_vocab_dict(self):
+    @cached_property
+    def _get_vocab_dict(self):
         """
         Return tokenizer vocab (dictionary for Formatron)
         """
         return {
             self.tokenizer.id_to_token(i): i
             for i in range(self.tokenizer.get_vocab_size())
         }
+    def get_vocab_dict(self):
+        return self._get_vocab_dict
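
`get_vocab_dict` follows the same delegation pattern, without `@synchronized` since the old version was not synchronized either. One property shared by all of these caches: repeated calls return the same object, not a copy, so a caller that mutates the returned dict or list would corrupt the cache for everyone else. The old `lru_cache` versions behaved the same way, so this is preserved behavior rather than a regression.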