Skip to content

Commit 3ce165a

Browse files
committed
HF backend: add guided_decoding (response_format) support
1 parent 1e180df commit 3ce165a

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

gpt_server/model_backend/hf_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
InvalidScoreLogitsProcessor,
1212
StoppingCriteriaList,
1313
StopAtSpecificTokenCriteria,
14+
XgrammarLogitsProcessor,
1415
)
1516
import asyncio
1617
from loguru import logger
@@ -30,6 +31,7 @@ class HFBackend(ModelBackend):
3031
def __init__(self, tokenizer, model: torch.nn.Module) -> None:
3132
self.model = model
3233
self.tokenizer = tokenizer
34+
self.xgrammar_processor = XgrammarLogitsProcessor(tokenizer)
3335
self.lora_requests = []
3436
lora = os.getenv("lora", None)
3537
if lora:
@@ -79,6 +81,28 @@ async def stream_chat(self, params: Dict[str, Any]):
7981
skip_prompt=True,
8082
decode_kwargsl={"skip_special_tokens": True},
8183
)
84+
# TODO
85+
# ---- 支持 response_format,但是官方对BPE分词器的支持仍然太差 ----
86+
response_format = params["response_format"]
87+
if response_format is not None:
88+
if response_format["type"] == "json_object":
89+
xgrammar_processor = (
90+
self.xgrammar_processor.get_json_grammar_processor()
91+
)
92+
logits_processor.append(xgrammar_processor)
93+
94+
elif response_format["type"] == "json_schema":
95+
json_schema = response_format["json_schema"]
96+
assert json_schema is not None
97+
guided_json = json_schema["schema"]
98+
xgrammar_processor = self.xgrammar_processor.get_json_schema_processor(
99+
schema=json.dumps(guided_json)
100+
)
101+
logits_processor.append(xgrammar_processor)
102+
elif response_format["type"] == "text":
103+
pass
104+
105+
# ---- 支持 response_format,但是官方对BPE分词器的支持仍然太差 ----
82106
generation_kwargs = dict(
83107
input_ids=input_ids.to(self.model.device),
84108
streamer=streamer,

gpt_server/model_backend/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
1818
self.grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
1919
# -----------
2020

21-
def get_grammar_compiler(self, schema: Union[str, Type[BaseModel]]):
21+
def get_json_grammar_processor(self):
    """Return an HF logits processor that constrains output to valid JSON.

    Compiles xgrammar's built-in JSON grammar and wraps it in a
    ``transformers``-compatible ``LogitsProcessor``.

    NOTE(review): the processor is also cached on ``self.xgr_logits_processor``,
    the same attribute used by ``get_json_schema_processor`` — a later call to
    either method overwrites it.
    """
    grammar = self.grammar_compiler.compile_builtin_json_grammar()
    processor = xgr.contrib.hf.LogitsProcessor(grammar)
    self.xgr_logits_processor = processor
    return processor
25+
26+
def get_json_schema_processor(self, schema: Union[str, Type[BaseModel]]):
    """Return an HF logits processor enforcing the given JSON schema.

    Parameters
    ----------
    schema:
        A JSON-schema string or a pydantic ``BaseModel`` subclass, passed
        straight through to ``GrammarCompiler.compile_json_schema``.

    NOTE(review): the processor is also cached on ``self.xgr_logits_processor``,
    the same attribute used by ``get_json_grammar_processor`` — a later call to
    either method overwrites it.
    """
    grammar = self.grammar_compiler.compile_json_schema(schema)
    processor = xgr.contrib.hf.LogitsProcessor(grammar)
    self.xgr_logits_processor = processor
    return processor

0 commit comments

Comments
 (0)