From 25eb2a207fde325476eda9fc4a1274684817cb23 Mon Sep 17 00:00:00 2001
From: Declan Brady
Date: Thu, 6 Nov 2025 17:46:42 -0500
Subject: [PATCH] Move stream_context out to preserve proper message ordering

---
 .../models/temporal_streaming_model.py        | 406 +++++++++---------
 1 file changed, 206 insertions(+), 200 deletions(-)

diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py
index 130532a6..3efa4da8 100644
--- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py
+++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py
@@ -546,219 +546,225 @@ async def get_response(
         # Process the stream of events from Responses API
         output_items = []
         current_text = ""
+        streaming_context = None
         reasoning_context = None
         reasoning_summaries = []
         reasoning_contents = []
-        current_reasoning_summary = ""
         event_count = 0
 
         # We expect task_id to always be provided for streaming
         if not task_id:
             raise ValueError("[TemporalStreamingModel] task_id is required for streaming model")
 
-        # Use proper async with context manager for streaming to Redis
-        async with adk.streaming.streaming_task_message_context(
-            task_id=task_id,
-            initial_content=TextContent(
-                author="agent",
-                content="",
-                format="markdown",
-            ),
-        ) as streaming_context:
-            # Process events from the Responses API stream
-            function_calls_in_progress = {}  # Track function calls being streamed
-
-            async for event in stream:
-                event_count += 1
-
-                # Log event type
-                logger.debug(f"[TemporalStreamingModel] Event {event_count}: {type(event).__name__}")
-
-                # Handle different event types using isinstance for type safety
-                if isinstance(event, ResponseOutputItemAddedEvent):
-                    # New output item (reasoning, function call, or message)
-                    item = getattr(event, 'item', None)
-                    output_index = getattr(event, 'output_index', 0)
-
-                    if item and getattr(item, 'type', None) == 'reasoning':
-                        logger.debug(f"[TemporalStreamingModel] Starting reasoning item")
-                        if not reasoning_context:
-                            # Start a reasoning context for streaming reasoning to UI
-                            reasoning_context = await adk.streaming.streaming_task_message_context(
-                                task_id=task_id,
-                                initial_content=ReasoningContent(
-                                    author="agent",
-                                    summary=[],
-                                    content=[],
-                                    type="reasoning",
-                                    style="active",
-                                ),
-                            ).__aenter__()
-                    elif item and getattr(item, 'type', None) == 'function_call':
-                        # Track the function call being streamed
-                        function_calls_in_progress[output_index] = {
-                            'id': getattr(item, 'id', ''),
-                            'call_id': getattr(item, 'call_id', ''),
-                            'name': getattr(item, 'name', ''),
-                            'arguments': getattr(item, 'arguments', ''),
-                        }
-                        logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}")
-
-                elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
-                    # Stream function call arguments
-                    output_index = getattr(event, 'output_index', 0)
-                    delta = getattr(event, 'delta', '')
-
-                    if output_index in function_calls_in_progress:
-                        function_calls_in_progress[output_index]['arguments'] += delta
-                        logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...")
-
-                elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
-                    # Function call arguments complete
-                    output_index = getattr(event, 'output_index', 0)
-                    arguments = getattr(event, 'arguments', '')
-
-                    if output_index in function_calls_in_progress:
-                        function_calls_in_progress[output_index]['arguments'] = arguments
logger.debug(f"[TemporalStreamingModel] Function call args done") - - elif isinstance(event, (ResponseReasoningTextDeltaEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent)): - # Handle text streaming - delta = getattr(event, 'delta', '') - - if isinstance(event, ResponseReasoningSummaryTextDeltaEvent) and reasoning_context: - # Stream reasoning summary deltas - these are the actual reasoning tokens! - try: - # Use ReasoningSummaryDelta for reasoning summaries - summary_index = getattr(event, 'summary_index', 0) - delta_obj = ReasoningSummaryDelta( - summary_index=summary_index, - summary_delta=delta, - type="reasoning_summary", - ) - update = StreamTaskMessageDelta( - parent_task_message=reasoning_context.task_message, - delta=delta_obj, - type="delta", - ) - await reasoning_context.stream_update(update) - # Accumulate the reasoning summary - if len(reasoning_summaries) <= summary_index: - reasoning_summaries.extend([""] * (summary_index + 1 - len(reasoning_summaries))) - reasoning_summaries[summary_index] += delta - logger.debug(f"[TemporalStreamingModel] Streamed reasoning summary: {delta[:30]}..." if len(delta) > 30 else f"[TemporalStreamingModel] Streamed reasoning summary: {delta}") - except Exception as e: - logger.warning(f"Failed to send reasoning delta: {e}") - elif isinstance(event, ResponseReasoningTextDeltaEvent) and reasoning_context: - # Regular reasoning delta (if these ever appear) - try: - delta_obj = ReasoningContentDelta( - content_index=0, - content_delta=delta, - type="reasoning_content", - ) - update = StreamTaskMessageDelta( - parent_task_message=reasoning_context.task_message, - delta=delta_obj, - type="delta", - ) - await reasoning_context.stream_update(update) - reasoning_contents.append(delta) - except Exception as e: - logger.warning(f"Failed to send reasoning delta: {e}") - elif isinstance(event, ResponseTextDeltaEvent): - # Stream regular text output - current_text += delta + # Process events from the Responses API stream + function_calls_in_progress = {} # Track function calls being streamed + + async for event in stream: + event_count += 1 + + # Log event type + logger.debug(f"[TemporalStreamingModel] Event {event_count}: {type(event).__name__}") + + # Handle different event types using isinstance for type safety + if isinstance(event, ResponseOutputItemAddedEvent): + # New output item (reasoning, function call, or message) + item = getattr(event, 'item', None) + output_index = getattr(event, 'output_index', 0) + + if item and getattr(item, 'type', None) == 'reasoning': + logger.debug(f"[TemporalStreamingModel] Starting reasoning item") + if not reasoning_context: + # Start a reasoning context for streaming reasoning to UI + reasoning_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=ReasoningContent( + author="agent", + summary=[], + content=[], + type="reasoning", + style="active", + ), + ).__aenter__() + elif item and getattr(item, 'type', None) == 'function_call': + # Track the function call being streamed + function_calls_in_progress[output_index] = { + 'id': getattr(item, 'id', ''), + 'call_id': getattr(item, 'call_id', ''), + 'name': getattr(item, 'name', ''), + 'arguments': getattr(item, 'arguments', ''), + } + logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}") + + elif item and getattr(item, 'type', None) == 'message': + # Track the message being streamed + streaming_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + 
+        # Process events from the Responses API stream
+        function_calls_in_progress = {}  # Track function calls being streamed
+
+        async for event in stream:
+            event_count += 1
+
+            # Log event type
+            logger.debug(f"[TemporalStreamingModel] Event {event_count}: {type(event).__name__}")
+
+            # Handle different event types using isinstance for type safety
+            if isinstance(event, ResponseOutputItemAddedEvent):
+                # New output item (reasoning, function call, or message)
+                item = getattr(event, 'item', None)
+                output_index = getattr(event, 'output_index', 0)
+
+                if item and getattr(item, 'type', None) == 'reasoning':
+                    logger.debug(f"[TemporalStreamingModel] Starting reasoning item")
+                    if not reasoning_context:
+                        # Start a reasoning context for streaming reasoning to UI
+                        reasoning_context = await adk.streaming.streaming_task_message_context(
+                            task_id=task_id,
+                            initial_content=ReasoningContent(
+                                author="agent",
+                                summary=[],
+                                content=[],
+                                type="reasoning",
+                                style="active",
+                            ),
+                        ).__aenter__()
+                elif item and getattr(item, 'type', None) == 'function_call':
+                    # Track the function call being streamed
+                    function_calls_in_progress[output_index] = {
+                        'id': getattr(item, 'id', ''),
+                        'call_id': getattr(item, 'call_id', ''),
+                        'name': getattr(item, 'name', ''),
+                        'arguments': getattr(item, 'arguments', ''),
+                    }
+                    logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}")
+
+                elif item and getattr(item, 'type', None) == 'message':
+                    # Track the message being streamed
+                    streaming_context = await adk.streaming.streaming_task_message_context(
+                        task_id=task_id,
+                        initial_content=TextContent(
+                            author="agent",
+                            content="",
+                            format="markdown",
+                        ),
+                    ).__aenter__()
+
+            elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
+                # Stream function call arguments
+                output_index = getattr(event, 'output_index', 0)
+                delta = getattr(event, 'delta', '')
+
+                if output_index in function_calls_in_progress:
+                    function_calls_in_progress[output_index]['arguments'] += delta
+                    logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...")
+
+            elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
+                # Function call arguments complete
+                output_index = getattr(event, 'output_index', 0)
+                arguments = getattr(event, 'arguments', '')
+
+                if output_index in function_calls_in_progress:
+                    function_calls_in_progress[output_index]['arguments'] = arguments
+                    logger.debug(f"[TemporalStreamingModel] Function call args done")
+
+            elif isinstance(event, (ResponseReasoningTextDeltaEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent)):
+                # Handle text streaming
+                delta = getattr(event, 'delta', '')
+
+                if isinstance(event, ResponseReasoningSummaryTextDeltaEvent) and reasoning_context:
+                    # Stream reasoning summary deltas - these are the actual reasoning tokens!
+                    try:
+                        # Use ReasoningSummaryDelta for reasoning summaries
+                        summary_index = getattr(event, 'summary_index', 0)
+                        delta_obj = ReasoningSummaryDelta(
+                            summary_index=summary_index,
+                            summary_delta=delta,
+                            type="reasoning_summary",
+                        )
+                        update = StreamTaskMessageDelta(
+                            parent_task_message=reasoning_context.task_message,
+                            delta=delta_obj,
+                            type="delta",
+                        )
+                        await reasoning_context.stream_update(update)
+                        # Accumulate the reasoning summary
+                        if len(reasoning_summaries) <= summary_index:
+                            logger.debug(f"[TemporalStreamingModel] Extending reasoning summaries: {summary_index}")
+                            reasoning_summaries.extend([""] * (summary_index + 1 - len(reasoning_summaries)))
+                        reasoning_summaries[summary_index] += delta
+                        logger.debug(f"[TemporalStreamingModel] Streamed reasoning summary: {delta[:30]}..." if len(delta) > 30 else f"[TemporalStreamingModel] Streamed reasoning summary: {delta}")
+                    except Exception as e:
+                        logger.warning(f"Failed to send reasoning delta: {e}")
+                elif isinstance(event, ResponseReasoningTextDeltaEvent) and reasoning_context:
+                    # Regular reasoning delta (if these ever appear)
+                    try:
+                        delta_obj = ReasoningContentDelta(
+                            content_index=0,
+                            content_delta=delta,
+                            type="reasoning_content",
+                        )
+                        update = StreamTaskMessageDelta(
+                            parent_task_message=reasoning_context.task_message,
+                            delta=delta_obj,
+                            type="delta",
+                        )
+                        await reasoning_context.stream_update(update)
+                        reasoning_contents.append(delta)
+                    except Exception as e:
+                        logger.warning(f"Failed to send reasoning delta: {e}")
+                elif isinstance(event, ResponseTextDeltaEvent):
+                    # Stream regular text output
+                    current_text += delta
+                    try:
+                        delta_obj = TextDelta(
+                            type="text",
+                            text_delta=delta,
+                        )
+                        update = StreamTaskMessageDelta(
+                            parent_task_message=streaming_context.task_message if streaming_context else None,
+                            delta=delta_obj,
+                            type="delta",
+                        )
+                        if streaming_context:
+                            await streaming_context.stream_update(update)
+                    except Exception as e:
+                        logger.warning(f"Failed to send text delta: {e}")
+
+            elif isinstance(event, ResponseOutputItemDoneEvent):
+                # Output item completed
+                item = getattr(event, 'item', None)
+                output_index = getattr(event, 'output_index', 0)
+
+                if item and getattr(item, 'type', None) == 'reasoning':
+                    if reasoning_context and reasoning_summaries:
+                        logger.debug(f"[TemporalStreamingModel] Reasoning item completed, sending final update")
+                        try:
+                            # Send a full message update with the complete reasoning content
+                            complete_reasoning_content = ReasoningContent(
+                                author="agent",
+                                summary=reasoning_summaries,  # Use accumulated summaries
+                                content=reasoning_contents if reasoning_contents else [],
+                                type="reasoning",
+                                style="static",
+                            )
+
+                            await reasoning_context.stream_update(
+                                update=StreamTaskMessageFull(
+                                    parent_task_message=reasoning_context.task_message,
+                                    content=complete_reasoning_content,
+                                    type="full",
+                                ),
+                            )
+
+                            # Close the reasoning context after sending the final update
+                            # This matches the reference implementation pattern
+                            await reasoning_context.close()
+                            reasoning_context = None
+                            logger.debug(f"[TemporalStreamingModel] Closed reasoning context after final update")
+                        except Exception as e:
+                            logger.warning(f"Failed to send reasoning part done update: {e}")
+
+                elif item and getattr(item, 'type', None) == 'function_call':
+                    # Function call completed - add to output
+                    if output_index in function_calls_in_progress:
+                        call_data = function_calls_in_progress[output_index]
+                        logger.debug(f"[TemporalStreamingModel] Function call completed: {call_data['name']}")
+
+                        # Create proper function call object
+                        tool_call = ResponseFunctionToolCall(
+                            id=call_data['id'],
+                            call_id=call_data['call_id'],
+                            type="function_call",
+                            name=call_data['name'],
+                            arguments=call_data['arguments'],
+                        )
+                        output_items.append(tool_call)
+
+            elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
+                # New reasoning part/summary started
+                part = getattr(event, 'part', None)
+                if part:
+                    part_type = getattr(part, 'type', 'unknown')
+                    logger.debug(f"[TemporalStreamingModel] New reasoning part: type={part_type}")
+
+            elif isinstance(event, ResponseReasoningSummaryPartDoneEvent):
+                # Reasoning part completed - ResponseOutputItemDoneEvent will handle the final update
+                logger.debug(f"[TemporalStreamingModel] Reasoning part completed")
+
+            elif isinstance(event, ResponseCompletedEvent):
+                # Response completed
+                logger.debug(f"[TemporalStreamingModel] Response completed")
+                response = getattr(event, 'response', None)
+                if response and hasattr(response, 'output'):
+                    # Use the final output from the response
+                    output_items = response.output
+                    logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response")
+
+        # End of event processing loop - close any open contexts
+        if reasoning_context:
+            await reasoning_context.close()
+            reasoning_context = None
+
+        if streaming_context:
+            await streaming_context.close()
+            streaming_context = None
 
         # Build the response from output items collected during streaming
         # Create output from the items we collected
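
The change boils down to one ordering rule: the text streaming context is opened lazily, when the model actually emits a 'message' output item, instead of eagerly before the event loop. A minimal sketch of why that preserves ordering, using a hypothetical open_context() stand-in for adk.streaming.streaming_task_message_context(...).__aenter__() (not the real AgentEx API):

    import asyncio

    opened: list[str] = []

    async def open_context(kind: str) -> str:
        # Hypothetical stand-in: records the order in which task messages open.
        opened.append(kind)
        return kind

    async def handle(events: list[str]) -> None:
        reasoning_context = None
        streaming_context = None  # no longer opened before the loop
        for kind in events:
            if kind == "reasoning" and reasoning_context is None:
                reasoning_context = await open_context("reasoning")
            elif kind == "message" and streaming_context is None:
                # Opened lazily: the text message is created only once the
                # model emits it, so it sorts after the reasoning message.
                streaming_context = await open_context("text")

    asyncio.run(handle(["reasoning", "message"]))
    assert opened == ["reasoning", "text"]
    # With the old eager `async with`, the text context opened first and
    # `opened` would have been ["text", "reasoning"].

The two explicit close() calls after the loop replace the cleanup that the removed async with block used to perform.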