diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index 94ea4876..db587657 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -102,10 +102,22 @@ def _serialize_item(item: Any) -> dict[str, Any]: class TemporalStreamingModel(Model): """Custom model implementation with streaming support.""" - def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True): - """Initialize the streaming model with OpenAI client and model name.""" - # Match the default behavior with no retries (Temporal handles retries) - self.client = AsyncOpenAI(max_retries=0) + def __init__( + self, + model_name: str = "gpt-4o", + _use_responses_api: bool = True, + openai_client: Optional[AsyncOpenAI] = None, + ): + """Initialize the streaming model with OpenAI client and model name. + + Args: + model_name: The name of the OpenAI model to use (default: "gpt-4o") + _use_responses_api: Internal flag for responses API (deprecated, always True) + openai_client: Optional custom AsyncOpenAI client. If not provided, a default + client with max_retries=0 will be created (since Temporal handles retries) + """ + # Use provided client or create default (Temporal handles retries) + self.client = openai_client if openai_client is not None else AsyncOpenAI(max_retries=0) self.model_name = model_name # Always use Responses API for all models self.use_responses_api = True @@ -114,7 +126,7 @@ def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True): agentex_client = create_async_agentex_client() self.tracer = AsyncTracer(agentex_client) - logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, tracer=initialized") + logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, custom_client={openai_client is not None}, tracer=initialized") def _non_null_or_not_given(self, value: Any) -> Any: """Convert None to NOT_GIVEN sentinel, matching OpenAI SDK pattern.""" diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py index c74a816b..e3d1b380 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py @@ -23,6 +23,7 @@ TResponseInputItem, AgentOutputSchemaBase, ) +from openai import AsyncOpenAI from openai.types.responses import ResponsePromptParam from agents.models.openai_responses import OpenAIResponsesModel from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel @@ -86,17 +87,25 @@ class TemporalTracingModelProvider(OpenAIProvider): the context interceptor enabled. """ - def __init__(self, *args, **kwargs): + def __init__(self, openai_client: Optional[AsyncOpenAI] = None, **kwargs): """Initialize the tracing model provider. - Accepts all the same arguments as OpenAIProvider. + Args: + openai_client: Optional custom AsyncOpenAI client. If provided, this client + will be used for all model calls. If not provided, OpenAIProvider + will create a default client. + **kwargs: All other arguments are passed to OpenAIProvider. """ - super().__init__(*args, **kwargs) + # Pass openai_client to parent if provided + if openai_client is not None: + super().__init__(openai_client=openai_client, **kwargs) + else: + super().__init__(**kwargs) # Initialize tracer for all models agentex_client = create_async_agentex_client() self._tracer = AsyncTracer(agentex_client) - logger.info("[TemporalTracingModelProvider] Initialized with AgentEx tracer") + logger.info(f"[TemporalTracingModelProvider] Initialized with AgentEx tracer, custom_client={openai_client is not None}") @override def get_model(self, model_name: Optional[str]) -> Model: