Skip to content

Commit f39614b

Browse files
Merge pull request #134 from scaleapi/bill/reasoning_streaming_bugfix
Bill/reasoning streaming bugfix
2 parents 900f66b + 82bdd4f commit f39614b

File tree

8 files changed

+2313
-774
lines changed

8 files changed

+2313
-774
lines changed

examples/tutorials/10_agentic/10_temporal/000_hello_acp/project/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self):
3434
@override
3535
async def on_task_event_send(self, params: SendEventParams) -> None:
3636
logger.info(f"Received task message instruction: {params}")
37-
37+
3838
# 2. Echo back the client's message to show it in the UI. This is not done by default so the agent developer has full control over what is shown to the user.
3939
await adk.messages.create(task_id=params.task.id, content=params.event.content)
4040

examples/tutorials/10_agentic/10_temporal/010_agent_chat/dev.ipynb

Lines changed: 1316 additions & 207 deletions
Large diffs are not rendered by default.

examples/tutorials/10_agentic/10_temporal/010_agent_chat/project/workflow.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ class StateModel(BaseModel):
4848
turn_number: int
4949

5050

51-
MCP_SERVERS = [
52-
StdioServerParameters(
53-
command="npx",
54-
args=["-y", "@modelcontextprotocol/server-sequential-thinking"],
55-
),
51+
MCP_SERVERS = [ # No longer needed due to reasoning
52+
# StdioServerParameters(
53+
# command="npx",
54+
# args=["-y", "@modelcontextprotocol/server-sequential-thinking"],
55+
# ),
5656
StdioServerParameters(
5757
command="uvx",
5858
args=["openai-websearch-mcp"],

examples/tutorials/10_agentic/10_temporal/010_agent_chat/uv.lock

Lines changed: 168 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/tutorials/10_agentic/10_temporal/050_agent_chat_guardrails/project/workflow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,6 @@ def __init__(self):
388388
@workflow.signal(name=SignalName.RECEIVE_EVENT)
389389
@override
390390
async def on_task_event_send(self, params: SendEventParams) -> None:
391-
logger.info(f"Received task message instruction: {params}")
392391

393392
if not params.event.content:
394393
return

src/agentex/lib/core/services/adk/providers/openai.py

Lines changed: 56 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from openai.types.responses import (
1313
ResponseCompletedEvent,
1414
ResponseTextDeltaEvent,
15+
ResponseFunctionToolCall,
1516
ResponseFunctionWebSearch,
1617
ResponseOutputItemDoneEvent,
17-
ResponseReasoningTextDoneEvent,
1818
ResponseCodeInterpreterToolCall,
19-
ResponseReasoningTextDeltaEvent,
20-
ResponseReasoningSummaryTextDoneEvent,
19+
ResponseReasoningSummaryPartDoneEvent,
20+
ResponseReasoningSummaryPartAddedEvent,
2121
ResponseReasoningSummaryTextDeltaEvent,
2222
)
2323

@@ -29,7 +29,6 @@
2929
from agentex.lib.core.tracing.tracer import AsyncTracer
3030
from agentex.types.task_message_delta import (
3131
TextDelta,
32-
ReasoningContentDelta,
3332
ReasoningSummaryDelta,
3433
)
3534
from agentex.types.task_message_update import (
@@ -691,7 +690,7 @@ async def run_agent_streamed_auto_send(
691690
if self.agentex_client is None:
692691
raise ValueError("Agentex client must be provided for auto_send methods")
693692

694-
tool_call_map: dict[str, Any] = {}
693+
tool_call_map: dict[str, ResponseFunctionToolCall] = {}
695694

696695
if self.tracer is None:
697696
raise RuntimeError("Tracer not initialized - ensure tracer is provided to OpenAIService")
@@ -756,6 +755,8 @@ async def run_agent_streamed_auto_send(
756755

757756
item_id_to_streaming_context: dict[str, StreamingTaskMessageContext] = {}
758757
unclosed_item_ids: set[str] = set()
758+
# Simple string to accumulate reasoning summary
759+
current_reasoning_summary: str = ""
759760

760761
try:
761762
# Process streaming events with TaskMessage creation
@@ -848,103 +849,75 @@ async def run_agent_streamed_auto_send(
848849
type="delta",
849850
),
850851
)
851-
852-
elif isinstance(event.data, ResponseReasoningSummaryTextDeltaEvent):
853-
# Handle reasoning summary text delta
852+
# Reasoning step one: new summary part added
853+
elif isinstance(event.data, ResponseReasoningSummaryPartAddedEvent):
854+
# We need to create a new streaming context for this reasoning item
854855
item_id = event.data.item_id
855-
summary_index = event.data.summary_index
856+
857+
# Reset the reasoning summary string
858+
current_reasoning_summary = ""
859+
860+
streaming_context = self.streaming_service.streaming_task_message_context(
861+
task_id=task_id,
862+
initial_content=ReasoningContent(
863+
author="agent",
864+
summary=[],
865+
content=[],
866+
type="reasoning",
867+
style="active",
868+
),
869+
)
856870

857-
# Check if we already have a streaming context for this reasoning item
858-
if item_id not in item_id_to_streaming_context:
859-
# Create a new streaming context for this reasoning item
860-
streaming_context = self.streaming_service.streaming_task_message_context(
861-
task_id=task_id,
862-
initial_content=ReasoningContent(
863-
author="agent",
864-
summary=[],
865-
content=[],
866-
type="reasoning",
867-
style="active",
868-
),
869-
)
870-
# Open the streaming context
871-
item_id_to_streaming_context[item_id] = await streaming_context.open()
872-
unclosed_item_ids.add(item_id)
873-
else:
874-
streaming_context = item_id_to_streaming_context[item_id]
871+
# Replace the existing streaming context (if it exists)
872+
# Why do we replace? Cause all the reasoning parts use the same item_id!
873+
item_id_to_streaming_context[item_id] = await streaming_context.open()
874+
unclosed_item_ids.add(item_id)
875+
876+
# Reasoning step two: handling summary text delta
877+
elif isinstance(event.data, ResponseReasoningSummaryTextDeltaEvent):
878+
# Accumulate the delta into the string
879+
current_reasoning_summary += event.data.delta
880+
streaming_context = item_id_to_streaming_context[item_id]
875881

876882
# Stream the summary delta through the streaming service
877883
await streaming_context.stream_update(
878884
update=StreamTaskMessageDelta(
879885
parent_task_message=streaming_context.task_message,
880886
delta=ReasoningSummaryDelta(
881-
summary_index=summary_index,
887+
summary_index=event.data.summary_index,
882888
summary_delta=event.data.delta,
883889
type="reasoning_summary",
884890
),
885891
type="delta",
886892
),
887893
)
888894

889-
elif isinstance(event.data, ResponseReasoningTextDeltaEvent):
890-
# Handle reasoning content text delta
891-
item_id = event.data.item_id
892-
content_index = event.data.content_index
893-
894-
# Check if we already have a streaming context for this reasoning item
895-
if item_id not in item_id_to_streaming_context:
896-
# Create a new streaming context for this reasoning item
897-
streaming_context = self.streaming_service.streaming_task_message_context(
898-
task_id=task_id,
899-
initial_content=ReasoningContent(
900-
author="agent",
901-
summary=[],
902-
content=[],
903-
type="reasoning",
904-
style="active",
905-
),
906-
)
907-
# Open the streaming context
908-
item_id_to_streaming_context[item_id] = await streaming_context.open()
909-
unclosed_item_ids.add(item_id)
910-
else:
911-
streaming_context = item_id_to_streaming_context[item_id]
912-
913-
# Stream the content delta through the streaming service
895+
# Reasoning step three: handling summary text done, closing the streaming context
896+
elif isinstance(event.data, ResponseReasoningSummaryPartDoneEvent):
897+
# Handle reasoning summary text completion
898+
streaming_context = item_id_to_streaming_context[item_id]
899+
900+
# Create the complete reasoning content with the accumulated summary
901+
complete_reasoning_content = ReasoningContent(
902+
author="agent",
903+
summary=[current_reasoning_summary],
904+
content=[],
905+
type="reasoning",
906+
style="static",
907+
)
908+
909+
# Send a full message update with the complete reasoning content
914910
await streaming_context.stream_update(
915-
update=StreamTaskMessageDelta(
911+
update=StreamTaskMessageFull(
916912
parent_task_message=streaming_context.task_message,
917-
delta=ReasoningContentDelta(
918-
content_index=content_index,
919-
content_delta=event.data.delta,
920-
type="reasoning_content",
921-
),
922-
type="delta",
913+
content=complete_reasoning_content,
914+
type="full",
923915
),
924916
)
925-
926-
elif isinstance(event.data, ResponseReasoningSummaryTextDoneEvent):
927-
# Handle reasoning summary text completion
928-
item_id = event.data.item_id
929-
summary_index = event.data.summary_index
930-
931-
# We do NOT close the streaming context here as there can be multiple
932-
# reasoning summaries. The context will be closed when the entire
933-
# output item is done (ResponseOutputItemDoneEvent)
934-
935-
# You would think they would use the event ResponseReasoningSummaryPartDoneEvent
936-
# to close the streaming context, but they do!!!
937-
# They output both a ResponseReasoningSummaryTextDoneEvent and a ResponseReasoningSummaryPartDoneEvent
938-
# I have no idea why they do this.
939-
940-
elif isinstance(event.data, ResponseReasoningTextDoneEvent):
941-
# Handle reasoning content text completion
942-
item_id = event.data.item_id
943-
content_index = event.data.content_index
944-
945-
# We do NOT close the streaming context here as there can be multiple
946-
# reasoning content texts. The context will be closed when the entire
947-
# output item is done (ResponseOutputItemDoneEvent)
917+
918+
await streaming_context.close()
919+
unclosed_item_ids.discard(item_id)
920+
948921

949922
elif isinstance(event.data, ResponseOutputItemDoneEvent):
950923
# Handle item completion

src/agentex/lib/utils/dev_tools/async_messages.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def print_task_message(
4949

5050
# Skip empty reasoning messages
5151
if isinstance(message.content, ReasoningContent):
52-
has_summary = message.content.summary and any(s for s in message.content.summary if s)
53-
has_content = message.content.content and any(c for c in message.content.content if c)
52+
has_summary = bool(message.content.summary) and any(s for s in message.content.summary if s)
53+
has_content = bool(message.content.content) and any(c for c in message.content.content if c) if message.content.content is not None else False
5454
if not has_summary and not has_content:
5555
return
5656

@@ -135,18 +135,19 @@ def print_task_message(
135135

136136
if rich_print and console:
137137
author_color = "bright_cyan" if message.content.author == "user" else "green"
138-
title = f"[bold {author_color}]{message.content.author.upper()}[/bold {author_color}] [{timestamp}]"
139138

140-
# Use different border styles for tool messages
139+
# Use different border styles and colors for different content types
141140
if content_type == "tool_request":
142141
border_style = "yellow"
143142
elif content_type == "tool_response":
144143
border_style = "bright_green"
145144
elif content_type == "reasoning":
146145
border_style = "bright_magenta"
146+
author_color = "bright_magenta" # Also make the author text magenta
147147
else:
148148
border_style = author_color
149-
149+
150+
title = f"[bold {author_color}]{message.content.author.upper()}[/bold {author_color}] [{timestamp}]"
150151
panel = Panel(Markdown(content), title=title, border_style=border_style, width=80)
151152
console.print(panel)
152153
else:
@@ -329,7 +330,7 @@ def subscribe_to_async_task_messages(
329330

330331
# Deserialize the discriminated union TaskMessageUpdate based on the "type" field
331332
message_type = task_message_update_data.get("type", "unknown")
332-
333+
333334
# Handle different message types for streaming progress
334335
if message_type == "start":
335336
task_message_update = StreamTaskMessageStart.model_validate(task_message_update_data)
@@ -359,6 +360,9 @@ def subscribe_to_async_task_messages(
359360
if index in active_spinners:
360361
active_spinners[index].stop()
361362
del active_spinners[index]
363+
# Ensure clean line after spinner
364+
if print_messages:
365+
print()
362366

363367
if task_message_update.parent_task_message and task_message_update.parent_task_message.id:
364368
finished_message = client.messages.retrieve(task_message_update.parent_task_message.id)
@@ -373,6 +377,9 @@ def subscribe_to_async_task_messages(
373377
if index in active_spinners:
374378
active_spinners[index].stop()
375379
del active_spinners[index]
380+
# Ensure clean line after spinner
381+
if print_messages:
382+
print()
376383

377384
if task_message_update.parent_task_message and task_message_update.parent_task_message.id:
378385
finished_message = client.messages.retrieve(task_message_update.parent_task_message.id)

0 commit comments

Comments
 (0)