@@ -102,10 +102,22 @@ def _serialize_item(item: Any) -> dict[str, Any]:
102102class TemporalStreamingModel (Model ):
103103 """Custom model implementation with streaming support."""
104104
105- def __init__ (self , model_name : str = "gpt-4o" , _use_responses_api : bool = True ):
106- """Initialize the streaming model with OpenAI client and model name."""
107- # Match the default behavior with no retries (Temporal handles retries)
108- self .client = AsyncOpenAI (max_retries = 0 )
105+ def __init__ (
106+ self ,
107+ model_name : str = "gpt-4o" ,
108+ _use_responses_api : bool = True ,
109+ openai_client : Optional [AsyncOpenAI ] = None ,
110+ ):
111+ """Initialize the streaming model with OpenAI client and model name.
112+
113+ Args:
114+ model_name: The name of the OpenAI model to use (default: "gpt-4o")
115+ _use_responses_api: Internal flag for responses API (deprecated, always True)
116+ openai_client: Optional custom AsyncOpenAI client. If not provided, a default
117+ client with max_retries=0 will be created (since Temporal handles retries)
118+ """
119+ # Use provided client or create default (Temporal handles retries)
120+ self .client = openai_client if openai_client is not None else AsyncOpenAI (max_retries = 0 )
109121 self .model_name = model_name
110122 # Always use Responses API for all models
111123 self .use_responses_api = True
@@ -114,7 +126,7 @@ def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True):
114126 agentex_client = create_async_agentex_client ()
115127 self .tracer = AsyncTracer (agentex_client )
116128
117- logger .info (f"[TemporalStreamingModel] Initialized model={ self .model_name } , use_responses_api={ self .use_responses_api } , tracer=initialized" )
129+ logger .info (f"[TemporalStreamingModel] Initialized model={ self .model_name } , use_responses_api={ self .use_responses_api } , custom_client= { openai_client is not None } , tracer=initialized" )
118130
119131 def _non_null_or_not_given (self , value : Any ) -> Any :
120132 """Convert None to NOT_GIVEN sentinel, matching OpenAI SDK pattern."""
0 commit comments