diff --git a/tests/reasoning/test_exaone4_reasoning_parser.py b/tests/reasoning/test_exaone4_reasoning_parser.py
new file mode 100644
index 000000000000..8566c7835cc5
--- /dev/null
+++ b/tests/reasoning/test_exaone4_reasoning_parser.py
@@ -0,0 +1,340 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
+parser_name = "exaone4"
+start_token = "<think>"
+end_token = "</think>"
+
+REASONING_MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-1.2B"
+
+
+@pytest.fixture(scope="module")
+def exaone4_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
+SIMPLE_REASONING = {
+    "output": "This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+COMPLETE_REASONING = {
+    "output": "This is a reasoning section</think>",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+    "is_reasoning_end": True,
+}
+NO_REASONING = {
+    "output": "This is content",
+    "reasoning_content": None,
+    "content": "This is content",
+    "is_reasoning_end": False,
+    "skip_extract_content": True,
+}
+NO_REASONING_STREAMING = {
+    "output": "This is a normal section",
+    "reasoning_content": None,
+    "content": "This is a normal section",
+    "is_reasoning_end": False,
+    "skip_extract_content": True,
+}
+NO_REASONING_STREAMING_WITH_THINK = {
+    "output": "This is a normal section",
+    "reasoning_content": "This is a normal section",
+    "content": None,
+    "is_reasoning_end": False,
+}
+MULTIPLE_LINES = {
+    "output": "This\nThat</think>This is the rest\nThat",
+    "reasoning_content": "This\nThat",
+    "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
+}
+SHORTEST_REASONING_NO_STREAMING = {
+    "output": "</think>This is the rest",
+    "reasoning_content": "",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+SHORTEST_REASONING = {
+    "output": "</think>This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+REASONING_WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+COMPLETE_REASONING_WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+    "is_reasoning_end": True,
+}
+MULTIPLE_LINES_WITH_THINK = {
+    "output": "<think>This\nThat</think>This is the rest\nThat",
+    "reasoning_content": "This\nThat",
+    "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
+}
+SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
+    "output": "</think>This is the rest",
+    "reasoning_content": "",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+SHORTEST_REASONING_WITH_THINK = {
+    "output": "</think>This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+THINK_NO_END = {
+    "output": "<think>This is a reasoning section",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+    "is_reasoning_end": False,
+}
+EMPTY = {
+    "output": "",
+    "reasoning_content": None,
+    "content": "",
+    "is_reasoning_end": False,
+}
+EMPTY_STREAMING = {
+    "output": "",
+    "reasoning_content": None,
+    "content": None,
+    "is_reasoning_end": False,
+}
+NEW_LINE = {
+    "output": "\n<think>This is a reasoning section</think>\nThis is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "\nThis is the rest",
+    "is_reasoning_end": True,
+}
+NEW_LINE_STREAMING = {
+    "output": "\n<think>This is a reasoning section</think>\nThis is the rest",
+    "reasoning_content": "\nThis is a reasoning section",
+    "content": "\nThis is the rest",
+    "is_reasoning_end": True,
+}
+
+TEST_CASES = [
+    pytest.param(
+        False,
+        SIMPLE_REASONING,
+        False,
+        id="simple_reasoning",
+    ),
+    pytest.param(
+        True,
+        SIMPLE_REASONING,
+        True,
+        id="simple_reasoning_streaming",
+    ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING,
+        False,
+        id="complete_reasoning",
+    ),
+    pytest.param(
+        True,
+        COMPLETE_REASONING,
+        True,
+        id="complete_reasoning_streaming",
+    ),
+    pytest.param(
+        False,
+        NO_REASONING,
+        False,
+        id="no_reasoning_token",
+    ),
+    pytest.param(
+        True,
+        NO_REASONING_STREAMING,
+        False,
+        id="no_reasoning_token_streaming",
+    ),
+    pytest.param(
+        True,
+        NO_REASONING_STREAMING_WITH_THINK,
+        True,
+        id="no_reasoning_token_streaming_with_think",
+    ),
+    pytest.param(
+        False,
+        MULTIPLE_LINES,
+        False,
+        id="multiple_lines",
+    ),
+    pytest.param(
+        True,
+        MULTIPLE_LINES,
+        True,
+        id="multiple_lines_streaming",
+    ),
+    pytest.param(
+        True,
+        SHORTEST_REASONING,
+        False,
+        id="shortest",
+    ),
+    pytest.param(
+        False,
+        SHORTEST_REASONING_NO_STREAMING,
+        True,
+        id="shortest_streaming",
+    ),
+    pytest.param(
+        False,
+        REASONING_WITH_THINK,
+        False,
+        id="reasoning_with_think",
+    ),
+    pytest.param(
+        True,
+        REASONING_WITH_THINK,
+        True,
+        id="reasoning_with_think_streaming",
+    ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING_WITH_THINK,
+        False,
+        id="complete_reasoning_with_think",
+    ),
+    pytest.param(
+        True,
+        COMPLETE_REASONING_WITH_THINK,
+        True,
+        id="complete_reasoning_with_think_streaming",
+    ),
+    pytest.param(
+        False,
+        MULTIPLE_LINES_WITH_THINK,
+        False,
+        id="multiple_lines_with_think",
+    ),
+    pytest.param(
+        True,
+        MULTIPLE_LINES_WITH_THINK,
+        True,
+        id="multiple_lines_with_think_streaming",
+    ),
+    pytest.param(
+        False,
+        SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
+        False,
+        id="shortest_with_think",
+    ),
+    pytest.param(
+        True,
+        SHORTEST_REASONING_WITH_THINK,
+        True,
+        id="shortest_with_think_streaming",
+    ),
+    pytest.param(
+        False,
+        THINK_NO_END,
+        False,
+        id="think_no_end",
+    ),
+    pytest.param(
+        True,
+        THINK_NO_END,
+        True,
+        id="think_no_end_streaming",
+    ),
+    pytest.param(
+        False,
+        EMPTY,
+        False,
+        id="empty",
+    ),
+    pytest.param(
+        True,
+        EMPTY_STREAMING,
+        True,
+        id="empty_streaming",
+    ),
+    pytest.param(
+        False,
+        NEW_LINE,
+        False,
+        id="new_line",
+    ),
+    pytest.param(
+        True,
+        NEW_LINE_STREAMING,
+        True,
+        id="new_line_streaming",
+    ),
+]
+
+
+@pytest.mark.parametrize("streaming, param_dict, enable_thinking", TEST_CASES)
+def test_reasoning(
+    streaming: bool,
+    param_dict: dict,
+    enable_thinking: bool,
+    exaone4_tokenizer,
+):
+    output = exaone4_tokenizer.tokenize(param_dict["output"])
+    # decode everything to tokens
+    output_tokens: list[str] = [
+        exaone4_tokenizer.convert_tokens_to_string([token])
+        for token in output
+    ]
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
+        parser_name)(exaone4_tokenizer)
+
+    dummy_request = ChatCompletionRequest(
+        messages=[],
+        chat_template_kwargs={"enable_thinking": enable_thinking},
+    )
+    reasoning, content = run_reasoning_extraction(
+        parser,
+        output_tokens,
+        request=dummy_request if enable_thinking else None,
+        streaming=streaming)
+
+    assert reasoning == param_dict["reasoning_content"]
+    assert content == param_dict["content"]
+
+    # Test is_reasoning_end
+    output_ids = exaone4_tokenizer.convert_tokens_to_ids(output)
+    is_reasoning_end = parser.is_reasoning_end(output_ids)
+    assert is_reasoning_end == param_dict["is_reasoning_end"]
+
+    # Test extract_content
+
+    # NOTE: For the `no_reasoning_token` cases, we skip the extract_content
+    # test. By default, the EXAONE 4.0 parser treats the whole output as
+    # content if there is no '<think>' or '</think>' and
+    # `enable_thinking=False`. However, `extract_content_ids()` cannot read
+    # `enable_thinking` from the request, and it is only used for removing
+    # the reasoning content from the output in
+    # vllm.entrypoints.openai.serving_chat.py. So we leave
+    # `extract_content_ids()` as is (it assumes the output is reasoning
+    # content whenever there is no '<think>' or '</think>').
+    if param_dict.get("skip_extract_content", False):
+        return
+
+    if param_dict["content"] is not None:
+        content = parser.extract_content_ids(output_ids)
+        assert content == exaone4_tokenizer.convert_tokens_to_ids(
+            exaone4_tokenizer.tokenize(param_dict["content"]))
+    else:
+        content = parser.extract_content_ids(output_ids)
+        assert content == []
diff --git a/tests/reasoning/test_granite_reasoning_parser.py b/tests/reasoning/test_granite_reasoning_parser.py
index 38cab73a45f2..d94362f3c405 100644
--- a/tests/reasoning/test_granite_reasoning_parser.py
+++ b/tests/reasoning/test_granite_reasoning_parser.py
@@ -336,6 +336,7 @@ def test_streaming_subcases(param_dict):
         previous_token_ids=previous_token_ids,
         current_token_ids=current_token_ids,
         delta_token_ids=delta_token_ids,
+        request=None,
     )
     # Streaming currently expects at least one of reasoning content / content,
     # so the response should return None in that case.
diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py
index 9af5fa5addbc..38d4d2eb8b58 100644
--- a/tests/reasoning/utils.py
+++ b/tests/reasoning/utils.py
@@ -115,6 +115,7 @@ def run_reasoning_extraction_streaming(
             previous_tokens,
             current_tokens,
             token_delta,
+            request,
         )
         if delta_message is not None:
             reconstructor.append_delta(delta_message)
@@ -147,6 +148,7 @@ def run_reasoning_extraction_streaming_mistral(
             previous_tokens,
             current_tokens,
             token_delta,
+            request,
         )
         if delta_message is not None:
             reconstructor.append_delta(delta_message)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index b4231c6d10c4..b1e1aadb9644 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -696,6 +696,7 @@ async def chat_completion_stream_generator(
                                 previous_token_ids,
                                 current_token_ids,
                                 output.token_ids,
+                                request,
                             ))
                         # When encountering think end id in delta_token_ids
                         # or think end id in prompt_token_ids
@@ -781,6 +782,7 @@ async def chat_completion_stream_generator(
                                 previous_token_ids,
                                 current_token_ids,
                                 output.token_ids,
+                                request,
                             ))
                         # When encountering think end id in prompt_token_ids
                         # i.e {"enable_thinking": False},
@@ -858,6 +860,7 @@ async def chat_completion_stream_generator(
                                 previous_token_ids,
                                 current_token_ids,
                                 output.token_ids,
+                                request,
                             ))
                         # handle streaming just a content delta
                         else:
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index b987adeb6428..98b3c4abc646 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -10,6 +10,7 @@
 from .mistral_reasoning_parser import MistralReasoningParser
 from .qwen3_reasoning_parser import Qwen3ReasoningParser
 from .step3_reasoning_parser import Step3ReasoningParser
+from .exaone4_reasoning_parser import Exaone4ReasoningParser
 
 __all__ = [
     "ReasoningParser",
@@ -22,4 +23,5 @@
     "MistralReasoningParser",
     "Step3ReasoningParser",
     "GptOssReasoningParser",
+    "Exaone4ReasoningParser",
 ]
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index 4f4522d726e8..5eede425f36e 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -105,6 +105,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
         """
         Instance method that should be implemented for extracting reasoning
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
index 1a5ca46a60f1..163eb8c6138f 100644
--- a/vllm/reasoning/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -64,6 +64,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
         """
         Extract reasoning content from a delta message.
diff --git a/vllm/reasoning/exaone4_reasoning_parser.py b/vllm/reasoning/exaone4_reasoning_parser.py
new file mode 100644
index 000000000000..cf9a3094fa94
--- /dev/null
+++ b/vllm/reasoning/exaone4_reasoning_parser.py
@@ -0,0 +1,162 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
+from vllm.reasoning.deepseek_r1_reasoning_parser import (
+    DeepSeekR1ReasoningParser)
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("exaone4")
+class Exaone4ReasoningParser(DeepSeekR1ReasoningParser):
+    """
+    Reasoning parser for the EXAONE 4.0 model.
+
+    The EXAONE 4.0 model uses <think>...</think> tokens to denote reasoning
+    text. This parser extracts the reasoning content from the model output.
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser "
+                "constructor during construction.")
+
+        self.start_token_id = self.vocab.get(self.start_token)
+        self.end_token_id = self.vocab.get(self.end_token)
+        if self.start_token_id is None or self.end_token_id is None:
+            raise RuntimeError(
+                "EXAONE 4.0 reasoning parser could not locate think start/end "
+                "tokens in the tokenizer!")
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        """
+        Extract reasoning content from a delta message.
+        Handles streaming output where previous + delta = current.
+        Uses token IDs for faster processing.
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+        """
+        # Skip single special tokens
+        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+                self.start_token_id, self.end_token_id
+        ]):
+            return None
+
+        # Check if <think> is present in previous or delta.
+        # Keep compatibility with models that don't generate <think> tokens.
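+        # The branches below key off where <think> and </think> have appeared
+        # so far (in previous_token_ids vs. this delta) and route the delta
+        # text to reasoning_content, content, or both accordingly.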
+        if self.start_token_id in previous_token_ids:
+            if self.end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+                # extract reasoning content
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+                # content continues
+                return DeltaMessage(content=delta_text)
+            else:
+                # <think> in previous, no </think> in previous or delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        elif self.start_token_id in delta_token_ids:
+            if self.end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta,
+                # extract reasoning content
+                start_index = delta_text.find(self.start_token)
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[start_index +
+                                               len(self.start_token):end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            else:
+                # <think> in delta, no </think> in delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        else:
+            # No <think> in previous or delta, so also check for </think>,
+            # because the model may have generated </think> without <think>.
+            # Ref https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B/blob/main/chat_template.jinja#L139-L146
+            if self.end_token_id in delta_token_ids:
+                # </think> in delta with more tokens,
+                # extract reasoning content and content
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
+                # </think> in previous, thinking content ends
+                return DeltaMessage(content=delta_text)
+            else:
+                # no </think> in previous or delta,
+                # use `enable_thinking` to determine if the model is thinking
+                if request is not None and \
+                        request.chat_template_kwargs is not None and \
+                        request.chat_template_kwargs.get("enable_thinking"):
+                    return DeltaMessage(reasoning_content=delta_text)
+                else:
+                    return DeltaMessage(content=delta_text)
+
+    def extract_reasoning_content(
+            self, model_output: str, request: ChatCompletionRequest
+    ) -> tuple[Optional[str], Optional[str]]:
+        """
+        Extract reasoning content from the model output.
+
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+
+        Returns:
+            tuple[Optional[str], Optional[str]]: reasoning content and content
+        """
+
+        # Check if the start token is present in the model output, remove it
+        # if it is present.
+        model_output_parts = model_output.partition(self.start_token)
+        model_output = model_output_parts[2] if model_output_parts[
+            1] else model_output_parts[0]
+
+        # EXAONE 4.0 doesn't generate <think> tokens.
+        # Ref https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B/blob/main/chat_template.jinja#L139-L146
+        if self.end_token not in model_output:
+            if model_output_parts[1]:
+                return model_output, None
+            return None, model_output
+        else:
+            reasoning_content, _, content = model_output.partition(
+                self.end_token)
+
+            final_content = content or None
+            return reasoning_content, final_content
diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py
index 460e38d2d396..76e566629054 100644
--- a/vllm/reasoning/glm4_moe_reasoning_parser.py
+++ b/vllm/reasoning/glm4_moe_reasoning_parser.py
@@ -64,6 +64,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
         """
         Extract reasoning content from a delta message.
diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py
index 05a72ac23bf2..38bd72811625 100644
--- a/vllm/reasoning/gptoss_reasoning_parser.py
+++ b/vllm/reasoning/gptoss_reasoning_parser.py
@@ -51,6 +51,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
         raise RuntimeError(
             "GptOss model uses harmony to extract reasoning content. This "
diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py
index 5820001b918f..a1aeed8b8fe9 100644
--- a/vllm/reasoning/granite_reasoning_parser.py
+++ b/vllm/reasoning/granite_reasoning_parser.py
@@ -83,6 +83,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
         """Extract the reasoning content / content emitted by granite models;
         If the sequence doesn't match what we expect, i.e., the model generates
diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
index b2452b95c1c6..de4908dffbbe 100644
--- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
+++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
@@ -152,6 +152,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
         """Extract content using token ID sequence state machine"""
         # Define sequences
diff --git a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py
index 61bafc724c17..7d167639a8db 100644
--- a/vllm/reasoning/qwen3_reasoning_parser.py
+++ b/vllm/reasoning/qwen3_reasoning_parser.py
@@ -64,6 +64,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
         """
         Extract reasoning content from a delta message.