@@ -357,9 +357,19 @@ def _build_reasoning_param(self, model_settings: ModelSettings) -> Any:
reasoning_param = {
"effort": model_settings.reasoning.effort,
}
# Add generate_summary if specified and not None
if hasattr(model_settings.reasoning, 'generate_summary') and model_settings.reasoning.generate_summary is not None:
reasoning_param["summary"] = model_settings.reasoning.generate_summary
# Add summary if specified (check both 'summary' and 'generate_summary' for compatibility)
summary_value = None
if hasattr(model_settings.reasoning, 'summary') and model_settings.reasoning.summary is not None:
summary_value = model_settings.reasoning.summary
elif (
hasattr(model_settings.reasoning, 'generate_summary')
and model_settings.reasoning.generate_summary is not None
):
summary_value = model_settings.reasoning.generate_summary

if summary_value is not None:
reasoning_param["summary"] = summary_value

logger.debug(f"[TemporalStreamingModel] Using reasoning param: {reasoning_param}")
return reasoning_param
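
As a note on the hunk above, here is a minimal, self-contained sketch of the intended precedence. FakeReasoning is a hypothetical stand-in for model_settings.reasoning (the real object comes from the SDK's ModelSettings): 'summary' wins, and the legacy 'generate_summary' is only a fallback.

# Illustration only; FakeReasoning mirrors the two field names used in the diff.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeReasoning:
    effort: str = "medium"
    summary: Optional[str] = None           # current field name
    generate_summary: Optional[str] = None  # legacy field name

def build_reasoning_param(reasoning: FakeReasoning) -> dict:
    param = {"effort": reasoning.effort}
    # Prefer 'summary'; fall back to 'generate_summary' for older settings objects.
    value = reasoning.summary if reasoning.summary is not None else reasoning.generate_summary
    if value is not None:
        param["summary"] = value
    return param

assert build_reasoning_param(FakeReasoning(summary="detailed"))["summary"] == "detailed"
assert build_reasoning_param(FakeReasoning(generate_summary="auto"))["summary"] == "auto"
assert "summary" not in build_reasoning_param(FakeReasoning())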

@@ -679,9 +689,34 @@ async def get_response(
output_index = getattr(event, 'output_index', 0)

if item and getattr(item, 'type', None) == 'reasoning':
logger.debug(f"[TemporalStreamingModel] Reasoning item completed")
# Don't close the context here - let it stay open for more reasoning events
# It will be closed when we send the final update or at the end
if reasoning_context and reasoning_summaries:
logger.debug(f"[TemporalStreamingModel] Reasoning itme completed, sending final update")
try:
# Send a full message update with the complete reasoning content
complete_reasoning_content = ReasoningContent(
author="agent",
summary=reasoning_summaries, # Use accumulated summaries
content=reasoning_contents if reasoning_contents else [],
type="reasoning",
style="static",
)

await reasoning_context.stream_update(
update=StreamTaskMessageFull(
parent_task_message=reasoning_context.task_message,
content=complete_reasoning_content,
type="full",
),
)

# Close the reasoning context after sending the final update
# This matches the reference implementation pattern
await reasoning_context.close()
reasoning_context = None
logger.debug(f"[TemporalStreamingModel] Closed reasoning context after final update")
except Exception as e:
logger.warning(f"Failed to send reasoning part done update: {e}")

elif item and getattr(item, 'type', None) == 'function_call':
# Function call completed - add to output
if output_index in function_calls_in_progress:
@@ -708,34 +743,8 @@ async def get_response(
current_reasoning_summary = ""

elif isinstance(event, ResponseReasoningSummaryPartDoneEvent):
# Reasoning part completed - send final update and close if this is the last part
if reasoning_context and reasoning_summaries:
logger.debug(f"[TemporalStreamingModel] Reasoning part completed, sending final update")
try:
# Send a full message update with the complete reasoning content
complete_reasoning_content = ReasoningContent(
author="agent",
summary=reasoning_summaries, # Use accumulated summaries
content=reasoning_contents if reasoning_contents else [],
type="reasoning",
style="static",
)

await reasoning_context.stream_update(
update=StreamTaskMessageFull(
parent_task_message=reasoning_context.task_message,
content=complete_reasoning_content,
type="full",
),
)

# Close the reasoning context after sending the final update
# This matches the reference implementation pattern
await reasoning_context.close()
reasoning_context = None
logger.debug(f"[TemporalStreamingModel] Closed reasoning context after final update")
except Exception as e:
logger.warning(f"Failed to send reasoning part done update: {e}")
# Reasoning part completed - ResponseOutputItemDoneEvent will handle the final update
logger.debug(f"[TemporalStreamingModel] Reasoning part completed")

elif isinstance(event, ResponseCompletedEvent):
# Response completed
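
Reviewer note on the two reasoning hunks above: the final reasoning update used to be sent from the ResponseReasoningSummaryPartDoneEvent branch and now lives in the ResponseOutputItemDoneEvent branch. A hedged sketch of why, assuming (per the removed comments) that a single reasoning item can emit several summary parts while the item-done event fires exactly once:

# Hypothetical event sequence for one reasoning item with two summary parts:
#   summary part 1 deltas arrive            -> accumulate into reasoning_summaries
#   ResponseReasoningSummaryPartDoneEvent   -> old code sent the full update and closed here (too early)
#   summary part 2 deltas arrive            -> would have found no open context
#   ResponseReasoningSummaryPartDoneEvent   -> second part ends; new code only logs
#   ResponseOutputItemDoneEvent (reasoning) -> new code sends StreamTaskMessageFull once, then closes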
@@ -842,10 +851,16 @@ def stream_response(self, *args, **kwargs):
class TemporalStreamingModelProvider(ModelProvider):
"""Custom model provider that returns a streaming-capable model."""

def __init__(self):
"""Initialize the provider."""
def __init__(self, openai_client: Optional[AsyncOpenAI] = None):
"""Initialize the provider.

Args:
openai_client: Optional custom AsyncOpenAI client to use for all models.
If not provided, each model will create its own default client.
"""
super().__init__()
logger.info("[TemporalStreamingModelProvider] Initialized")
self.openai_client = openai_client
logger.info(f"[TemporalStreamingModelProvider] Initialized, custom_client={openai_client is not None}")

@override
def get_model(self, model_name: Union[str, None]) -> Model:
@@ -860,5 +875,5 @@ def get_model(self, model_name: Union[str, None]) -> Model:
# Use the provided model_name or default to gpt-4o
actual_model = model_name if model_name else "gpt-4o"
logger.info(f"[TemporalStreamingModelProvider] Creating TemporalStreamingModel for model_name: {actual_model}")
model = TemporalStreamingModel(model_name=actual_model)
model = TemporalStreamingModel(model_name=actual_model, openai_client=self.openai_client)
return model
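
A usage sketch for the new constructor parameter; base_url and api_key below are placeholders, not values from this PR. Passing a shared AsyncOpenAI client routes every model the provider creates through that client (for example a proxy or a custom retry policy); omitting it preserves the old behavior where each model builds its own default client.

from openai import AsyncOpenAI

client = AsyncOpenAI(
    base_url="https://llm-proxy.internal.example/v1",  # placeholder endpoint
    api_key="sk-...",                                  # placeholder key
    max_retries=5,
)
provider = TemporalStreamingModelProvider(openai_client=client)
model = provider.get_model("gpt-4o")  # a TemporalStreamingModel bound to `client`

default_provider = TemporalStreamingModelProvider()  # each model creates its own client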