Skip to content

Commit f6b8a6d

Browse files
authored
feat: support reasoning_content in OpenAI Chat Completions (#20786)
feat: support reasoning_content in OpenAI Chat Completions streaming. Extract the reasoning_content field from streaming and non-streaming Chat Completion responses, surfacing it as a ThinkingBlock and as thinking_delta so that agents can consume chain-of-thought output from any OpenAI-compatible provider. Also skip ThinkingBlock when round-tripping messages back to the Chat Completions API. Closes #19124.
1 parent d978208 commit f6b8a6d

File tree

3 files changed

+283
-12
lines changed

3 files changed

+283
-12
lines changed

llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
CompletionResponseGen,
4545
LLMMetadata,
4646
MessageRole,
47+
ThinkingBlock,
4748
ToolCallBlock,
4849
TextBlock,
4950
)
@@ -540,6 +541,7 @@ def _stream_chat(
540541

541542
def gen() -> ChatResponseGen:
542543
content = ""
544+
reasoning_content = ""
543545
tool_calls: List[ChoiceDeltaToolCall] = []
544546

545547
is_function = False
@@ -565,13 +567,24 @@ def gen() -> ChatResponseGen:
565567
role = delta.role or MessageRole.ASSISTANT
566568
content_delta = delta.content or ""
567569
content += content_delta
570+
571+
# Extract reasoning_content for chain-of-thought streaming.
572+
# Many OpenAI-compatible providers surface this extra field.
573+
raw_reasoning = getattr(delta, "reasoning_content", None)
574+
reasoning_delta = (
575+
raw_reasoning if isinstance(raw_reasoning, str) else ""
576+
)
577+
reasoning_content += reasoning_delta
578+
579+
if reasoning_content:
580+
blocks.append(ThinkingBlock(content=reasoning_content))
568581
blocks.append(TextBlock(text=content))
569582

570-
additional_kwargs = {}
583+
message_additional_kwargs = {}
571584
if is_function:
572585
tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
573586
if tool_calls:
574-
additional_kwargs["tool_calls"] = tool_calls
587+
message_additional_kwargs["tool_calls"] = tool_calls
575588
for tool_call in tool_calls:
576589
if tool_call.function:
577590
blocks.append(
@@ -582,15 +595,21 @@ def gen() -> ChatResponseGen:
582595
)
583596
)
584597

598+
# thinking_delta goes in ChatResponse.additional_kwargs
599+
# (same pattern as Ollama) so agents can read it
600+
response_additional_kwargs = self._get_response_token_counts(response)
601+
if reasoning_delta:
602+
response_additional_kwargs["thinking_delta"] = reasoning_delta
603+
585604
yield ChatResponse(
586605
message=ChatMessage(
587606
role=role,
588607
blocks=blocks,
589-
additional_kwargs=additional_kwargs,
608+
additional_kwargs=message_additional_kwargs,
590609
),
591610
delta=content_delta,
592611
raw=response,
593-
additional_kwargs=self._get_response_token_counts(response),
612+
additional_kwargs=response_additional_kwargs,
594613
)
595614

596615
return gen()
@@ -807,6 +826,7 @@ async def _astream_chat(
807826

808827
async def gen() -> ChatResponseAsyncGen:
809828
content = ""
829+
reasoning_content = ""
810830
tool_calls: List[ChoiceDeltaToolCall] = []
811831

812832
is_function = False
@@ -843,13 +863,24 @@ async def gen() -> ChatResponseAsyncGen:
843863
role = delta.role or MessageRole.ASSISTANT
844864
content_delta = delta.content or ""
845865
content += content_delta
866+
867+
# Extract reasoning_content for chain-of-thought streaming.
868+
# Many OpenAI-compatible providers surface this extra field.
869+
raw_reasoning = getattr(delta, "reasoning_content", None)
870+
reasoning_delta = (
871+
raw_reasoning if isinstance(raw_reasoning, str) else ""
872+
)
873+
reasoning_content += reasoning_delta
874+
875+
if reasoning_content:
876+
blocks.append(ThinkingBlock(content=reasoning_content))
846877
blocks.append(TextBlock(text=content))
847878

848-
additional_kwargs = {}
879+
message_additional_kwargs = {}
849880
if is_function:
850881
tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
851882
if tool_calls:
852-
additional_kwargs["tool_calls"] = tool_calls
883+
message_additional_kwargs["tool_calls"] = tool_calls
853884
for tool_call in tool_calls:
854885
if tool_call.function:
855886
blocks.append(
@@ -860,15 +891,21 @@ async def gen() -> ChatResponseAsyncGen:
860891
)
861892
)
862893

894+
# thinking_delta goes in ChatResponse.additional_kwargs
895+
# (same pattern as Ollama) so agents can read it
896+
response_additional_kwargs = self._get_response_token_counts(response)
897+
if reasoning_delta:
898+
response_additional_kwargs["thinking_delta"] = reasoning_delta
899+
863900
yield ChatResponse(
864901
message=ChatMessage(
865902
role=role,
866903
blocks=blocks,
867-
additional_kwargs=additional_kwargs,
904+
additional_kwargs=message_additional_kwargs,
868905
),
869906
delta=content_delta,
870907
raw=response,
871-
additional_kwargs=self._get_response_token_counts(response),
908+
additional_kwargs=response_additional_kwargs,
872909
)
873910

874911
return gen()

llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ def to_openai_message_dict(
423423
},
424424
}
425425
)
426+
elif isinstance(block, ThinkingBlock):
427+
# ThinkingBlock is not supported in the Chat Completions API input;
428+
# skip it when converting messages back (round-tripping).
429+
continue
426430
elif isinstance(block, ToolCallBlock):
427431
try:
428432
function_dict = {
@@ -736,11 +740,17 @@ def from_openai_message(
736740
) -> ChatMessage:
737741
"""Convert openai message dict to generic message."""
738742
role = openai_message.role
743+
blocks: List[ContentBlock] = []
744+
745+
# Extract reasoning_content if present (used by many OpenAI-compatible
746+
# providers for chain-of-thought responses)
747+
reasoning_content = getattr(openai_message, "reasoning_content", None)
748+
if isinstance(reasoning_content, str) and reasoning_content:
749+
blocks.append(ThinkingBlock(content=reasoning_content))
750+
739751
# NOTE: Azure OpenAI returns function calling messages without a content key
740752
if "text" in modalities and openai_message.content:
741-
blocks: List[ContentBlock] = [TextBlock(text=openai_message.content or "")]
742-
else:
743-
blocks: List[ContentBlock] = []
753+
blocks.append(TextBlock(text=openai_message.content or ""))
744754

745755
additional_kwargs: Dict[str, Any] = {}
746756
if openai_message.tool_calls:

llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai.py

Lines changed: 225 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from unittest.mock import AsyncMock, MagicMock, patch
44

55
import pytest
6-
from llama_index.core.base.llms.types import ChatMessage
6+
from llama_index.core.base.llms.types import ChatMessage, ThinkingBlock, TextBlock
77
from llama_index.llms.openai import OpenAI
88
from llama_index.llms.openai.utils import O1_MODELS
99

@@ -601,3 +601,227 @@ def test_reasoning_effort_none_default():
601601
llm = OpenAI(model=model_name, api_key="test-key")
602602
kwargs = llm._get_model_kwargs()
603603
assert "reasoning_effort" not in kwargs
604+
605+
606+
# ===== reasoning_content tests (OpenAI-compatible providers) =====
607+
608+
609+
def _make_chunk(
610+
delta_kwargs: dict, finish_reason: Optional[str] = None
611+
) -> ChatCompletionChunk:
612+
"""Helper to create a single ChatCompletionChunk."""
613+
extra = delta_kwargs.pop("__extra__", None)
614+
chunk = ChatCompletionChunk(
615+
id="chatcmpl-reasoning",
616+
object="chat.completion.chunk",
617+
created=1700000000,
618+
model="qwen3-thinking",
619+
choices=[
620+
ChunkChoice(
621+
delta=ChoiceDelta(**delta_kwargs),
622+
finish_reason=finish_reason,
623+
index=0,
624+
)
625+
],
626+
)
627+
if extra:
628+
chunk.choices[0].delta.__pydantic_extra__ = extra
629+
return chunk
630+
631+
632+
def _make_reasoning_stream_chunks() -> list[ChatCompletionChunk]:
633+
"""Simulate an OpenAI-compatible API streaming reasoning_content then content."""
634+
return [
635+
_make_chunk({"role": "assistant"}),
636+
_make_chunk(
637+
{"content": None, "__extra__": {"reasoning_content": "Let me think"}}
638+
),
639+
_make_chunk(
640+
{"content": None, "__extra__": {"reasoning_content": " about this."}}
641+
),
642+
_make_chunk({"content": "The answer"}),
643+
_make_chunk({"content": " is 42."}),
644+
_make_chunk({}, finish_reason="stop"),
645+
]
646+
647+
648+
@patch("llama_index.llms.openai.base.SyncOpenAI")
649+
def test_stream_chat_reasoning_content(MockSyncOpenAI: MagicMock) -> None:
650+
"""Test that reasoning_content from streaming is captured as ThinkingBlock and thinking_delta."""
651+
with CachedOpenAIApiKeys(set_fake_key=True):
652+
mock_instance = MockSyncOpenAI.return_value
653+
mock_instance.chat.completions.create.return_value = iter(
654+
_make_reasoning_stream_chunks()
655+
)
656+
657+
llm = OpenAI(model="gpt-4o", api_key="test-key")
658+
responses = list(llm.stream_chat([ChatMessage(role="user", content="test")]))
659+
660+
final = responses[-1]
661+
thinking_blocks = [
662+
b for b in final.message.blocks if isinstance(b, ThinkingBlock)
663+
]
664+
text_blocks = [b for b in final.message.blocks if isinstance(b, TextBlock)]
665+
666+
assert len(thinking_blocks) == 1
667+
assert thinking_blocks[0].content == "Let me think about this."
668+
assert len(text_blocks) == 1
669+
assert text_blocks[0].text == "The answer is 42."
670+
671+
# Exactly 2 chunks carry thinking_delta (the two reasoning chunks)
672+
reasoning_chunks = [
673+
r for r in responses if r.additional_kwargs.get("thinking_delta")
674+
]
675+
assert len(reasoning_chunks) == 2
676+
assert reasoning_chunks[0].additional_kwargs["thinking_delta"] == "Let me think"
677+
assert reasoning_chunks[1].additional_kwargs["thinking_delta"] == " about this."
678+
679+
680+
@pytest.mark.asyncio()
681+
@patch("llama_index.llms.openai.base.AsyncOpenAI")
682+
async def test_astream_chat_reasoning_content(MockAsyncOpenAI: MagicMock) -> None:
683+
"""Test that reasoning_content from async streaming is captured as ThinkingBlock."""
684+
mock_instance = MockAsyncOpenAI.return_value
685+
686+
async def mock_async_stream(*args: Any, **kwargs: Any) -> AsyncGenerator:
687+
for chunk in _make_reasoning_stream_chunks():
688+
yield chunk
689+
690+
create_fn = AsyncMock()
691+
create_fn.return_value = mock_async_stream()
692+
mock_instance.chat.completions.create = create_fn
693+
694+
llm = OpenAI(model="gpt-4o", api_key="test-key")
695+
response_gen = await llm.astream_chat([ChatMessage(role="user", content="test")])
696+
responses = [r async for r in response_gen]
697+
698+
final = responses[-1]
699+
thinking_blocks = [b for b in final.message.blocks if isinstance(b, ThinkingBlock)]
700+
text_blocks = [b for b in final.message.blocks if isinstance(b, TextBlock)]
701+
702+
assert len(thinking_blocks) == 1
703+
assert thinking_blocks[0].content == "Let me think about this."
704+
assert len(text_blocks) == 1
705+
assert text_blocks[0].text == "The answer is 42."
706+
707+
# Verify thinking_delta on async path too
708+
reasoning_chunks = [
709+
r for r in responses if r.additional_kwargs.get("thinking_delta")
710+
]
711+
assert len(reasoning_chunks) == 2
712+
713+
714+
@patch("llama_index.llms.openai.base.SyncOpenAI")
715+
def test_chat_reasoning_content_non_streaming(MockSyncOpenAI: MagicMock) -> None:
716+
"""Test that reasoning_content in non-streaming responses is captured as ThinkingBlock."""
717+
with CachedOpenAIApiKeys(set_fake_key=True):
718+
response = ChatCompletion(
719+
id="chatcmpl-reasoning",
720+
object="chat.completion",
721+
created=1700000000,
722+
model="qwen3-thinking",
723+
choices=[
724+
Choice(
725+
message=ChatCompletionMessage(
726+
role="assistant",
727+
content="The answer is 42.",
728+
),
729+
finish_reason="stop",
730+
index=0,
731+
)
732+
],
733+
)
734+
response.choices[0].message.__pydantic_extra__ = {
735+
"reasoning_content": "Let me think step by step..."
736+
}
737+
738+
mock_instance = MockSyncOpenAI.return_value
739+
mock_instance.chat.completions.create.return_value = response
740+
741+
llm = OpenAI(model="gpt-4o", api_key="test-key")
742+
result = llm.chat([ChatMessage(role="user", content="test")])
743+
744+
thinking_blocks = [
745+
b for b in result.message.blocks if isinstance(b, ThinkingBlock)
746+
]
747+
text_blocks = [b for b in result.message.blocks if isinstance(b, TextBlock)]
748+
749+
assert len(thinking_blocks) == 1
750+
assert thinking_blocks[0].content == "Let me think step by step..."
751+
assert len(text_blocks) == 1
752+
assert text_blocks[0].text == "The answer is 42."
753+
754+
755+
@patch("llama_index.llms.openai.base.SyncOpenAI")
756+
def test_stream_chat_no_reasoning_content(MockSyncOpenAI: MagicMock) -> None:
757+
"""Test that streaming without reasoning_content produces no ThinkingBlock."""
758+
with CachedOpenAIApiKeys(set_fake_key=True):
759+
mock_instance = MockSyncOpenAI.return_value
760+
mock_instance.chat.completions.create.return_value = (
761+
mock_chat_completion_stream_v1()
762+
)
763+
764+
llm = OpenAI(model="gpt-4o", api_key="test-key")
765+
responses = list(llm.stream_chat([ChatMessage(role="user", content="test")]))
766+
767+
final = responses[-1]
768+
thinking_blocks = [
769+
b for b in final.message.blocks if isinstance(b, ThinkingBlock)
770+
]
771+
assert len(thinking_blocks) == 0
772+
assert final.message.content == "\n\n2"
773+
774+
775+
def test_to_openai_message_dict_skips_thinking_block() -> None:
776+
"""Test that ThinkingBlock is skipped when converting messages to OpenAI format."""
777+
from llama_index.llms.openai.utils import to_openai_message_dict
778+
779+
message = ChatMessage(
780+
role="assistant",
781+
blocks=[
782+
ThinkingBlock(content="internal reasoning"),
783+
TextBlock(text="The answer is 42."),
784+
],
785+
)
786+
787+
result = to_openai_message_dict(message)
788+
assert result["role"] == "assistant"
789+
assert result["content"] == "The answer is 42."
790+
791+
792+
def test_from_openai_message_with_reasoning_content() -> None:
793+
"""Test that from_openai_message extracts reasoning_content as ThinkingBlock."""
794+
from llama_index.llms.openai.utils import from_openai_message
795+
796+
openai_msg = ChatCompletionMessage(
797+
role="assistant",
798+
content="The answer is 42.",
799+
)
800+
openai_msg.__pydantic_extra__ = {"reasoning_content": "Let me think..."}
801+
802+
result = from_openai_message(openai_msg, modalities=["text"])
803+
804+
thinking_blocks = [b for b in result.blocks if isinstance(b, ThinkingBlock)]
805+
text_blocks = [b for b in result.blocks if isinstance(b, TextBlock)]
806+
807+
assert len(thinking_blocks) == 1
808+
assert thinking_blocks[0].content == "Let me think..."
809+
assert len(text_blocks) == 1
810+
assert text_blocks[0].text == "The answer is 42."
811+
812+
813+
def test_from_openai_message_without_reasoning_content() -> None:
814+
"""Test that from_openai_message works normally without reasoning_content."""
815+
from llama_index.llms.openai.utils import from_openai_message
816+
817+
openai_msg = ChatCompletionMessage(
818+
role="assistant",
819+
content="Hello!",
820+
)
821+
822+
result = from_openai_message(openai_msg, modalities=["text"])
823+
824+
thinking_blocks = [b for b in result.blocks if isinstance(b, ThinkingBlock)]
825+
assert len(thinking_blocks) == 0
826+
assert len(result.blocks) == 1
827+
assert result.blocks[0].text == "Hello!"

0 commit comments

Comments
 (0)