Commit a47e6ff

Authored by heheda12345, LiuXiaoxuanPKU, simon-mo, WoosukKwon, hongxiayang
[GptOss] Add GptOss reasoning parser to support structure output (#22322)
Signed-off-by: Chen Zhang <[email protected]>
Co-authored-by: LiuXiaoxuanPKU <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: Minseok Lee <[email protected]>
Co-authored-by: Yongye Zhu <[email protected]>
1 parent: 98a3a81

File tree: 3 files changed (+69, −3)

vllm/model_executor/models/config.py (3 additions, 3 deletions)

```diff
@@ -247,13 +247,13 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
                                      config.max_model_len)
 
 
-class GptOssConfig(VerifyAndUpdateConfig):
+class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
 
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         decoding_config = vllm_config.decoding_config
         if decoding_config.reasoning_backend == "":
-            decoding_config.reasoning_backend = "openai"
+            decoding_config.reasoning_backend = "GptOss"
 
         # Increase the max capture size from 512 to 1024 for performance.
         # NOTE(woosuk): This will increase the number of CUDA graphs
@@ -373,5 +373,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
-    "GptOssForCausalLM": GptOssConfig,
+    "GptOssForCausalLM": GptOssForCausalLMConfig,
 }
```
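The rename keeps the config-update hook named after the `GptOssForCausalLM` architecture key it is registered under, and the default reasoning backend for gpt-oss changes from "openai" to the new "GptOss" parser whenever the user has not set one. Below is a minimal sketch of that fallback behavior; the `DecodingConfig` stand-in is simplified for illustration and is not vLLM's real config object:

```python
# Minimal sketch of the default-selection hunk above. DecodingConfig here is
# a stripped-down, hypothetical stand-in for vLLM's real config object.
from dataclasses import dataclass


@dataclass
class DecodingConfig:
    reasoning_backend: str = ""  # empty string means "not set by the user"


def apply_gptoss_defaults(decoding_config: DecodingConfig) -> None:
    # Mirrors GptOssForCausalLMConfig.verify_and_update_config: fill in the
    # backend only when the user has not chosen one explicitly.
    if decoding_config.reasoning_backend == "":
        decoding_config.reasoning_backend = "GptOss"


cfg = DecodingConfig()
apply_gptoss_defaults(cfg)
assert cfg.reasoning_backend == "GptOss"  # default applied

explicit = DecodingConfig(reasoning_backend="openai")
apply_gptoss_defaults(explicit)
assert explicit.reasoning_backend == "openai"  # user choice preserved
```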

vllm/reasoning/__init__.py (2 additions, 0 deletions)

```diff
@@ -4,6 +4,7 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
+from .gptoss_reasoning_parser import GptOssReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
 from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
 from .mistral_reasoning_parser import MistralReasoningParser
@@ -20,4 +21,5 @@
     "Glm4MoeModelReasoningParser",
     "MistralReasoningParser",
     "Step3ReasoningParser",
+    "GptOssReasoningParser",
 ]
```
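The package-level import and `__all__` entry expose `GptOssReasoningParser` from `vllm.reasoning`, matching how the other parsers are exported. For readers unfamiliar with the decorator-based registry the new file relies on, here is a toy re-implementation of the `register_module` pattern; `ParserRegistry` and its methods are hypothetical names for illustration, not vLLM's actual API:

```python
# Toy re-implementation of a register_module-style class registry, in the
# spirit of ReasoningParserManager. All names here are hypothetical.
from typing import Callable


class ParserRegistry:
    _parsers: dict[str, type] = {}

    @classmethod
    def register_module(cls, name: str) -> Callable[[type], type]:
        # Return a class decorator that records the class under `name`.
        def wrapper(parser_cls: type) -> type:
            cls._parsers[name] = parser_cls
            return parser_cls

        return wrapper

    @classmethod
    def get(cls, name: str) -> type:
        return cls._parsers[name]


@ParserRegistry.register_module("GptOss")
class DummyParser:
    pass


assert ParserRegistry.get("GptOss") is DummyParser
```

The string "GptOss" set as the default backend in config.py above is resolved through this kind of name-to-class lookup.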
vllm/reasoning/gptoss_reasoning_parser.py (new file, 64 additions)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("GptOss")
class GptOssReasoningParser(ReasoningParser):
    """
    Reasoning parser for GptOss model.

    The GptOss model uses harmony to extract reasoning content and this parser
    is only used for detecting the end of the reasoning content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.reasoning_end_token_ids = self.model_tokenizer.encode(
            "<|start|>assistant<|channel|>final<|message|>")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        end_token_ids = self.reasoning_end_token_ids
        assert len(end_token_ids) > 0, "reasoning_end_token_ids is empty"
        # Check if the end sequence is present in the input_ids.
        # We search from the end of input_ids to find the last match.
        for i in range(len(input_ids) - len(end_token_ids), -1, -1):
            if input_ids[i:i + len(end_token_ids)] == end_token_ids:
                return True
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        raise RuntimeError(
            "GptOss model uses harmony to extract reasoning content. This "
            "function should not be called.")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        raise RuntimeError(
            "GptOss model uses harmony to extract reasoning content. This "
            "function should not be called.")

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        raise RuntimeError(
            "GptOss model uses harmony to extract reasoning content. This "
            "function should not be called.")
```
