 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast, overload
 
+import llguidance as llg
+import regex as re
 from mistral_common.protocol.instruct.request import (
     ChatCompletionRequest as MistralChatCompletionRequest,
 )
 from mistral_common.tokens.tokenizers.base import (
     SpecialTokenPolicy,
     SpecialTokens,
+    Tokenizer,
+)
+from mistral_common.tokens.tokenizers.instruct import (
+    InstructTokenizerBase,
+    InstructTokenizerV13,
+)
+from mistral_common.tokens.tokenizers.mistral import (
+    MistralTokenizer as MistralCommonTokenizer,
 )
-from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
 from mistral_common.tokens.tokenizers.sentencepiece import (
     SentencePieceTokenizer,
 )
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
 from vllm.logger import init_logger
+from vllm.tokenizers.protocol import TokenizerLike
+
+try:
+    # Transformers v5
+    from transformers.tokenization_mistral_common import MistralCommonBackend
+except ImportError:
+    # Transformers v4
+    from transformers.tokenization_mistral_common import (
+        MistralCommonTokenizer as MistralCommonBackend,
+    )
 
-from .protocol import TokenizerLike
 
 if TYPE_CHECKING:
     from transformers import BatchEncoding
 
-try:
-    # Transformers v5
-    from transformers.tokenization_mistral_common import MistralCommonBackend
-except ImportError:
-    # Transformers v4
-    from transformers.tokenization_mistral_common import (
-        MistralCommonTokenizer as MistralCommonBackend,
-    )
 
 logger = init_logger(__name__)
 
@@ -217,15 +227,6 @@ def from_pretrained(
         download_dir: str | None = None,
         **kwargs,
     ) -> "MistralTokenizer":
-        try:
-            # Transformers v5
-            from transformers.tokenization_mistral_common import MistralCommonBackend
-        except ImportError:
-            # Transformers v4
-            from transformers.tokenization_mistral_common import (
-                MistralCommonTokenizer as MistralCommonBackend,
-            )
-
         tokenizer = MistralCommonBackend.from_pretrained(
             path_or_repo_id,
             *args,
@@ -240,10 +241,10 @@ def from_pretrained(
     def __init__(self, tokenizer: "MistralCommonBackend") -> None:
         super().__init__()
 
-        self.transformers_tokenizer = tokenizer
-        self.mistral = tokenizer.tokenizer
-        self.instruct = self.mistral.instruct_tokenizer
-        self.tokenizer = self.instruct.tokenizer
+        self.transformers_tokenizer: MistralCommonBackend = tokenizer
+        self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
+        self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
+        self.tokenizer: Tokenizer = self.instruct.tokenizer
 
         mode = self.mistral._chat_completion_request_validator._mode
         if mode != ValidationMode.test:
@@ -509,7 +510,7 @@ def convert_ids_to_tokens(
             return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
 
         non_skip_special_tokens_ids = {
-            self.tokenizer.get_control_token(SpecialTokens.tool_calls),
+            self.tokenizer.get_special_token(SpecialTokens.tool_calls),
         }
         if isinstance(self.instruct, InstructTokenizerV13):
             if self.instruct.BEGIN_THINK:
@@ -541,3 +542,66 @@ def convert_ids_to_tokens(
         ]
 
         return tokens
+
+
+class MistralLLGTokenizer:
+    """Wraps a mistral tokenizer for use with llguidance."""
+
+    eos_token_id: int
+    bos_token_id: int
+    tokens: list[bytes]
+    special_token_ids: list[int]
+
+    def __init__(self, tokenizer: MistralTokenizer) -> None:
+        self._tokenizer = tokenizer.tokenizer
+        self.eos_token_id = self._tokenizer.eos_id
+        self.bos_token_id = self._tokenizer.bos_id
+
+        self.tokens: list[bytes] = []
+        self.special_token_ids: list[int] = []
+
+        seen_special_tokens: set[str] = set()
+        for i in range(self._tokenizer.n_words):
+            # Convert square brackets to angle brackets for special tokens,
+            # since llg only recognizes the latter.
+            if self._tokenizer.is_special(i):
+                token_rep = self._tokenizer.id_to_piece(i)
+                if match := re.fullmatch(r"\[(.*)\]", token_rep):
+                    token_rep_llg = f"<{match.group(1)}>"
+                else:
+                    token_rep_llg = token_rep
+
+                if not re.fullmatch(r"<.*>", token_rep_llg):
+                    raise ValueError(
+                        f"Invalid special token: {token_rep_llg} ({token_rep})"
+                    )
+                assert token_rep_llg not in seen_special_tokens, (
+                    token_rep_llg,
+                    seen_special_tokens,
+                )
+                seen_special_tokens.add(token_rep_llg)
+                self.special_token_ids.append(i)
+                self.tokens.append(token_rep_llg.encode("utf-8"))
+            else:
+                token = self._tokenizer.id_to_byte_piece(i, SpecialTokenPolicy.RAISE)
+                self.tokens.append(token)
+
+        assert len(self.special_token_ids) == self._tokenizer.num_special_tokens, (
+            len(self.special_token_ids),
+            self._tokenizer.num_special_tokens,
+        )
+
594+ def __call__ (self , s : str , * args , ** kwds ) -> list [int ]:
595+ # HACK: add a null byte to the start of the string to avoid the tokenizer
596+ # absorbing the first character of tokens that start with "▁".
597+ # we then ignore the first two tokens the "▁" and the null byte.
598+ # This gives us the pure tokenized text without SP shit.
599+ if isinstance (self ._tokenizer , SentencePieceTokenizer ):
600+ return self ._tokenizer .encode ("\00 " + s , bos = False , eos = False )[2 :]
601+ else :
602+ return self ._tokenizer .encode (s , bos = False , eos = False )
+
+
+def guidance_tokenizer_from_mistral_tokenizer(
+    tokenizer: MistralTokenizer,
+) -> llg.LLTokenizer:
+    tokenizer_data = MistralLLGTokenizer(tokenizer)
+    return llg.LLTokenizer(llg.TokenizerWrapper(tokenizer_data))