
Commit 7d4cc18

Merge pull request #196 from scaleapi/dm/client-provider
Add client to temporal streaming and tracing provider
2 parents: 27e0e0a + 2a3862f

2 files changed: +30 −9 lines changed

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 17 additions & 5 deletions
@@ -102,10 +102,22 @@ def _serialize_item(item: Any) -> dict[str, Any]:
 class TemporalStreamingModel(Model):
     """Custom model implementation with streaming support."""
 
-    def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True):
-        """Initialize the streaming model with OpenAI client and model name."""
-        # Match the default behavior with no retries (Temporal handles retries)
-        self.client = AsyncOpenAI(max_retries=0)
+    def __init__(
+        self,
+        model_name: str = "gpt-4o",
+        _use_responses_api: bool = True,
+        openai_client: Optional[AsyncOpenAI] = None,
+    ):
+        """Initialize the streaming model with OpenAI client and model name.
+
+        Args:
+            model_name: The name of the OpenAI model to use (default: "gpt-4o")
+            _use_responses_api: Internal flag for responses API (deprecated, always True)
+            openai_client: Optional custom AsyncOpenAI client. If not provided, a default
+                client with max_retries=0 will be created (since Temporal handles retries)
+        """
+        # Use provided client or create default (Temporal handles retries)
+        self.client = openai_client if openai_client is not None else AsyncOpenAI(max_retries=0)
         self.model_name = model_name
         # Always use Responses API for all models
         self.use_responses_api = True
@@ -114,7 +126,7 @@ def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True):
         agentex_client = create_async_agentex_client()
         self.tracer = AsyncTracer(agentex_client)
 
-        logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, tracer=initialized")
+        logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, custom_client={openai_client is not None}, tracer=initialized")
 
     def _non_null_or_not_given(self, value: Any) -> Any:
         """Convert None to NOT_GIVEN sentinel, matching OpenAI SDK pattern."""

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py

Lines changed: 13 additions & 4 deletions
@@ -23,6 +23,7 @@
     TResponseInputItem,
     AgentOutputSchemaBase,
 )
+from openai import AsyncOpenAI
 from openai.types.responses import ResponsePromptParam
 from agents.models.openai_responses import OpenAIResponsesModel
 from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
@@ -86,17 +87,25 @@ class TemporalTracingModelProvider(OpenAIProvider):
     the context interceptor enabled.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, openai_client: Optional[AsyncOpenAI] = None, **kwargs):
         """Initialize the tracing model provider.
 
-        Accepts all the same arguments as OpenAIProvider.
+        Args:
+            openai_client: Optional custom AsyncOpenAI client. If provided, this client
+                will be used for all model calls. If not provided, OpenAIProvider
+                will create a default client.
+            **kwargs: All other arguments are passed to OpenAIProvider.
         """
-        super().__init__(*args, **kwargs)
+        # Pass openai_client to parent if provided
+        if openai_client is not None:
+            super().__init__(openai_client=openai_client, **kwargs)
+        else:
+            super().__init__(**kwargs)
 
         # Initialize tracer for all models
         agentex_client = create_async_agentex_client()
         self._tracer = AsyncTracer(agentex_client)
-        logger.info("[TemporalTracingModelProvider] Initialized with AgentEx tracer")
+        logger.info(f"[TemporalTracingModelProvider] Initialized with AgentEx tracer, custom_client={openai_client is not None}")
 
     @override
     def get_model(self, model_name: Optional[str]) -> Model:
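
Similarly, a minimal usage sketch for the provider side (not part of this diff; the import path is inferred from the file location above, and the API key is a placeholder). The injected client is forwarded to OpenAIProvider and used for every model the provider returns.

    from openai import AsyncOpenAI

    # Hypothetical import path, inferred from the file path shown in this diff.
    from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_tracing_model import (
        TemporalTracingModelProvider,
    )

    # Inject a preconfigured client; the provider passes it through to OpenAIProvider.
    provider = TemporalTracingModelProvider(
        openai_client=AsyncOpenAI(api_key="sk-...", max_retries=0),  # placeholder key
    )
    model = provider.get_model("gpt-4o")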
