Skip to content

Commit 8cdcce9

Browse files
authored
[None][fix] Handle missing closing </think> tag during streaming for nemotron model (NVIDIA#12176)
Signed-off-by: Joyjit Daw <1127155+tijyojwad@users.noreply.github.com>
1 parent 3380889 commit 8cdcce9

File tree

4 files changed

+159
-5
lines changed

4 files changed

+159
-5
lines changed

tensorrt_llm/llmapi/reasoning_parser.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def parse(self, text: str) -> ReasoningParserResult:
6969
def parse_delta(self, delta_text: str) -> ReasoningParserResult:
7070
raise NotImplementedError
7171

72+
def finish(self) -> ReasoningParserResult:
    """Hook invoked once the token stream has ended.

    Subclasses may override this to flush buffered state or to
    reclassify content accumulated during streaming. The base
    implementation simply yields an empty result.
    """
    return ReasoningParserResult()
77+
7278

7379
@register_reasoning_parser("deepseek-r1", reasoning_at_start=True)
7480
@register_reasoning_parser("qwen3")
@@ -185,6 +191,12 @@ def __init__(self,
185191
"force_nonempty_content", False) is True
186192
super().__init__(reasoning_at_start=reasoning_at_start,
187193
chat_template_kwargs=chat_template_kwargs)
194+
# Workaround: the model sometimes does not send closing think tags
195+
# which affects downstream applications. This is addressed by
196+
# optionally accumulating reasoning tokens and returning them as
197+
# content at the end of streaming.
198+
self._accumulated_reasoning = ""
199+
self._found_closing_tag = False
188200

189201
def _maybe_swap_content(
190202
self, result: ReasoningParserResult) -> ReasoningParserResult:
@@ -195,5 +207,50 @@ def _maybe_swap_content(
195207
reasoning_content="")
196208
return result
197209

210+
def parse_delta(self, delta_text: str) -> ReasoningParserResult:
    """Wrap the parent ``parse_delta`` to track accumulated reasoning.

    When ``force_nonempty_content`` is enabled, reasoning tokens are
    collected so they can later be re-emitted as content if the stream
    ends without a closing think tag. Once the closing tag is observed
    (``in_reasoning`` flips from True to False), the accumulator is
    discarded to free memory.
    """
    previously_reasoning = self.in_reasoning
    parsed = super().parse_delta(delta_text)

    # Nothing to track unless the workaround is active.
    if not self._force_nonempty_content:
        return parsed

    if parsed.reasoning_content:
        self._accumulated_reasoning += parsed.reasoning_content

    # A True -> False transition of in_reasoning marks the closing tag.
    if previously_reasoning and not self.in_reasoning:
        self._found_closing_tag = True
        self._accumulated_reasoning = ""

    return parsed
224+
225+
def finish(self) -> ReasoningParserResult:
    """Flush parser state when the stream ends.

    If no closing think tag was ever seen and ``force_nonempty_content``
    is set, the full accumulated reasoning is returned as content so the
    response is never empty. Without ``force_nonempty_content``, any
    leftover buffer is returned as reasoning_content since the parser is
    still inside the reasoning section.

    If the closing tag was already found (or reasoning was never
    entered), any leftover buffer is flushed as regular content.
    """
    # Drain the delta buffer exactly once, up front.
    leftover = self._buffer
    self._buffer = ""

    still_reasoning = self.in_reasoning and not self._found_closing_tag
    if not still_reasoning:
        # Reasoning section already closed (or never opened): leftover
        # text, if any, belongs to the content section.
        if leftover:
            return ReasoningParserResult(content=leftover)
        return ReasoningParserResult()

    # Stream ended while still inside reasoning: close the section.
    self.in_reasoning = False

    if self._force_nonempty_content:
        # Reclassify everything accumulated so far as content.
        combined = self._accumulated_reasoning + leftover
        self._accumulated_reasoning = ""
        return ReasoningParserResult(content=combined)

    self._accumulated_reasoning = ""
    if leftover:
        return ReasoningParserResult(reasoning_content=leftover)
    return ReasoningParserResult()
254+
198255
def parse(self, text: str) -> ReasoningParserResult:
    """Parse the complete text, then apply the enable-thinking
    content swap on the parent's result."""
    parsed = super().parse(text)
    return self._maybe_swap_content(parsed)

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from ..executor.result import Logprob, TokenLogprobs
1313
from ..llmapi import SamplingParams
1414
from ..llmapi.reasoning_parser import (BaseReasoningParser,
15-
ReasoningParserFactory)
15+
ReasoningParserFactory,
16+
ReasoningParserResult)
1617
from ..llmapi.tokenizer import TransformersTokenizer
1718
# yapf: disable
1819
from .chat_utils import make_tool_call_id
@@ -111,8 +112,11 @@ def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
111112
return chat_logprobs
112113

113114

114-
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
115-
streaming: bool) -> Tuple[str, str]:
115+
def apply_reasoning_parser(args: ChatPostprocArgs,
116+
output_index: int,
117+
text: str,
118+
streaming: bool,
119+
finished: bool = False) -> Tuple[str, str]:
116120
reasoning_parser = None
117121
if args.reasoning_parser is not None:
118122
if output_index not in args.reasoning_parser_dict:
@@ -127,6 +131,13 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
127131
result = reasoning_parser.parse(text)
128132
else:
129133
result = reasoning_parser.parse_delta(text)
134+
if finished:
135+
finish_result = reasoning_parser.finish()
136+
result = ReasoningParserResult(
137+
content=result.content + finish_result.content,
138+
reasoning_content=result.reasoning_content +
139+
finish_result.reasoning_content,
140+
)
130141
content, reasoning_content = result.content, result.reasoning_content
131142
else:
132143
content, reasoning_content = text, ""
@@ -214,7 +225,11 @@ def yield_first_chat(num_tokens: int,
214225
delta_text = output.text_diff
215226

216227
delta_text, reasoning_delta_text = apply_reasoning_parser(
217-
args, i, delta_text, True)
228+
args,
229+
i,
230+
delta_text,
231+
True,
232+
finished=(output.finish_reason is not None))
218233

219234
if args.tool_choice and type(
220235
args.tool_choice) is ChatCompletionNamedToolChoiceParam:

tensorrt_llm/serve/responses_utils.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
from tensorrt_llm.llmapi import SamplingParams
4747
from tensorrt_llm.llmapi.llm import RequestOutput
4848
from tensorrt_llm.llmapi.reasoning_parser import (BaseReasoningParser,
49-
ReasoningParserFactory)
49+
ReasoningParserFactory,
50+
ReasoningParserResult)
5051
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer
5152
from tensorrt_llm.logger import logger
5253
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
@@ -927,6 +928,7 @@ def _apply_reasoning_parser(
927928
text: str,
928929
streaming: bool,
929930
reasoning_parser_dict: Optional[dict[int, BaseReasoningParser]] = None,
931+
finished: bool = False,
930932
) -> Tuple[str, str]:
931933
reasoning_parser: Optional[BaseReasoningParser] = None
932934
if reasoning_parser_id is not None:
@@ -946,6 +948,13 @@ def _apply_reasoning_parser(
946948
result = reasoning_parser.parse(text)
947949
else:
948950
result = reasoning_parser.parse_delta(text)
951+
if finished:
952+
finish_result = reasoning_parser.finish()
953+
result = ReasoningParserResult(
954+
content=result.content + finish_result.content,
955+
reasoning_content=result.reasoning_content +
956+
finish_result.reasoning_content,
957+
)
949958
content, reasoning_content = result.content, result.reasoning_content
950959
else:
951960
content, reasoning_content = text, ""
@@ -1490,6 +1499,14 @@ def _should_send_done_events(
14901499
should_send_reasoning_done = True
14911500
reasoning_content = full_reasoning
14921501

1502+
# No closing tag: reasoning was streamed but re-parse shows everything as
1503+
# content (no </think> found). Close the reasoning section so the text
1504+
# section can be properly opened and closed.
1505+
if not full_reasoning and full_text and finished_generation:
1506+
if streaming_events_helper and streaming_events_helper.is_reasoning_sent:
1507+
should_send_reasoning_done = True
1508+
reasoning_content = full_text
1509+
14931510
return should_send_reasoning_done, should_send_text_done, reasoning_content, text_content
14941511

14951512

@@ -1525,6 +1542,7 @@ def check_parser(parser_id: Optional[str],
15251542
text=delta_text,
15261543
streaming=True,
15271544
reasoning_parser_dict=reasoning_parser_dict,
1545+
finished=finished_generation,
15281546
)
15291547

15301548
if delta_text:
@@ -1595,6 +1613,37 @@ def check_parser(parser_id: Optional[str],
15951613
streaming_events_helper.is_output_item_added_sent = False
15961614
streaming_events_helper.is_text_sent = False
15971615

1616+
# Handle no-closing-tag case: reasoning was streamed but finish() moved
1617+
# all accumulated reasoning to content. Emit the full text section
1618+
# lifecycle (added → delta → done) since the reasoning section was just
1619+
# closed and generation is finished.
1620+
if (finished_generation and delta_text and should_send_reasoning_done
1621+
and not should_send_text_done):
1622+
streaming_events_helper.is_text_sent = True
1623+
yield from streaming_events_helper.get_message_output_added_events()
1624+
yield streaming_events_helper.get_text_delta_event(delta_text, [])
1625+
text_content_obj = ResponseOutputText(
1626+
text=delta_text,
1627+
annotations=[],
1628+
type="output_text",
1629+
logprobs=None,
1630+
)
1631+
text_item = ResponseOutputMessage(
1632+
id=streaming_events_helper.item_id,
1633+
content=[text_content_obj],
1634+
role="assistant",
1635+
status="completed",
1636+
type="message",
1637+
)
1638+
yield streaming_events_helper.get_text_done_event(delta_text, [])
1639+
yield streaming_events_helper.get_content_part_done_event(
1640+
text_content_obj)
1641+
yield streaming_events_helper.get_output_item_done_event(text_item)
1642+
streaming_events_helper.output_index_increment()
1643+
streaming_events_helper.is_output_item_added_sent = False
1644+
streaming_events_helper.is_text_sent = False
1645+
delta_text = ""
1646+
15981647
# Send delta events for ongoing content
15991648
if delta_text:
16001649
if delta_text.strip():

tests/unittest/llmapi/test_reasoning_parser.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,36 @@ def test_nano_v3_reasoning_parser_stream(delta_texts: list, content: list,
160160
print(f"delta_text: {delta_text}, result: {result}")
161161
assert result.content == content[i]
162162
assert result.reasoning_content == reasoning_context[i]
163+
164+
165+
@pytest.mark.parametrize(
    ("delta_texts", "finish_content", "finish_reasoning",
     "chat_template_kwargs"),
    [
        # Default parser: finish() never reclassifies anything.
        (["a", "b"], "", "", None),
        ([R1_END, "a", "b"], "", "", None),
        (["a", R1_END, "b"], "", "", None),
        # enable_thinking=False leaves finish() empty as well.
        (["a", "b"], "", "", {"enable_thinking": False}),
        ([f"{R1_START}a", "b"], "", "", {"enable_thinking": False}),
        # force_nonempty_content disabled: same as default behavior.
        (["a", "b"], "", "", {"force_nonempty_content": False}),
        # force_nonempty_content enabled and no closing tag seen:
        # accumulated reasoning is returned as content.
        (["a", "b"], "ab", "", {"force_nonempty_content": True}),
        # Closing tag was seen: nothing left to flush.
        ([R1_END, "a", "b"], "", "", {"force_nonempty_content": True}),
    ])
def test_nano_v3_reasoning_parser_finish(delta_texts: list, finish_content: str,
                                         finish_reasoning: str,
                                         chat_template_kwargs: dict):
    """finish() flushes accumulated reasoning as content only when
    force_nonempty_content is set and no closing think tag appeared."""
    parser = ReasoningParserFactory.create_reasoning_parser(
        "nano-v3", chat_template_kwargs)
    for chunk in delta_texts:
        parser.parse_delta(chunk)
    final = parser.finish()
    assert final.content == finish_content
    assert final.reasoning_content == finish_reasoning
assert result.reasoning_content == finish_reasoning

0 commit comments

Comments
 (0)