1111 InvalidScoreLogitsProcessor ,
1212 StoppingCriteriaList ,
1313 StopAtSpecificTokenCriteria ,
14+ XgrammarLogitsProcessor ,
1415)
1516import asyncio
1617from loguru import logger
@@ -30,6 +31,7 @@ class HFBackend(ModelBackend):
3031 def __init__ (self , tokenizer , model : torch .nn .Module ) -> None :
3132 self .model = model
3233 self .tokenizer = tokenizer
34+ self .xgrammar_processor = XgrammarLogitsProcessor (tokenizer )
3335 self .lora_requests = []
3436 lora = os .getenv ("lora" , None )
3537 if lora :
@@ -79,6 +81,28 @@ async def stream_chat(self, params: Dict[str, Any]):
7981 skip_prompt = True ,
8082 decode_kwargsl = {"skip_special_tokens" : True },
8183 )
84+ # TODO
85+ # ---- Support response_format, but official support for BPE tokenizers is still too poor ----
86+ response_format = params ["response_format" ]
87+ if response_format is not None :
88+ if response_format ["type" ] == "json_object" :
89+ xgrammar_processor = (
90+ self .xgrammar_processor .get_json_grammar_processor ()
91+ )
92+ logits_processor .append (xgrammar_processor )
93+
94+ elif response_format ["type" ] == "json_schema" :
95+ json_schema = response_format ["json_schema" ]
96+ assert json_schema is not None
97+ guided_json = json_schema ["schema" ]
98+ xgrammar_processor = self .xgrammar_processor .get_json_schema_processor (
99+ schema = json .dumps (guided_json )
100+ )
101+ logits_processor .append (xgrammar_processor )
102+ elif response_format ["type" ] == "text" :
103+ pass
104+
105+ # ---- Support response_format, but official support for BPE tokenizers is still too poor ----
82106 generation_kwargs = dict (
83107 input_ids = input_ids .to (self .model .device ),
84108 streamer = streamer ,
0 commit comments