@@ -102,10 +102,22 @@ def _serialize_item(item: Any) -> dict[str, Any]:
 class TemporalStreamingModel(Model):
     """Custom model implementation with streaming support."""
 
-    def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True):
-        """Initialize the streaming model with OpenAI client and model name."""
-        # Match the default behavior with no retries (Temporal handles retries)
-        self.client = AsyncOpenAI(max_retries=0)
+    def __init__(
+        self,
+        model_name: str = "gpt-4o",
+        _use_responses_api: bool = True,
+        openai_client: Optional[AsyncOpenAI] = None,
+    ):
+        """Initialize the streaming model with OpenAI client and model name.
+
+        Args:
+            model_name: The name of the OpenAI model to use (default: "gpt-4o")
+            _use_responses_api: Internal flag for responses API (deprecated, always True)
+            openai_client: Optional custom AsyncOpenAI client. If not provided, a default
+                client with max_retries=0 will be created (since Temporal handles retries)
+        """
+        # Use provided client or create default (Temporal handles retries)
+        self.client = openai_client if openai_client is not None else AsyncOpenAI(max_retries=0)
         self.model_name = model_name
         # Always use Responses API for all models
         self.use_responses_api = True
@@ -114,7 +126,7 @@ def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True):
         agentex_client = create_async_agentex_client()
         self.tracer = AsyncTracer(agentex_client)
 
-        logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, tracer=initialized")
+        logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, custom_client={openai_client is not None}, tracer=initialized")
 
     def _non_null_or_not_given(self, value: Any) -> Any:
         """Convert None to NOT_GIVEN sentinel, matching OpenAI SDK pattern."""
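With this change, callers can inject a preconfigured AsyncOpenAI client (for example, one pointed at a proxy or gateway) instead of the default. A minimal usage sketch, assuming a hypothetical import path for TemporalStreamingModel; the base_url value is illustrative, not part of the PR:

from openai import AsyncOpenAI

from temporal_streaming_model import TemporalStreamingModel  # hypothetical import path

# Hypothetical gateway URL; max_retries=0 keeps retry control with Temporal.
client = AsyncOpenAI(
    base_url="https://llm-gateway.internal.example/v1",
    max_retries=0,
)

# Omitting openai_client preserves the old behavior: AsyncOpenAI(max_retries=0).
model = TemporalStreamingModel(model_name="gpt-4o", openai_client=client)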
@@ -23,6 +23,7 @@
     TResponseInputItem,
     AgentOutputSchemaBase,
 )
+from openai import AsyncOpenAI
 from openai.types.responses import ResponsePromptParam
 from agents.models.openai_responses import OpenAIResponsesModel
 from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
@@ -86,17 +87,25 @@ class TemporalTracingModelProvider(OpenAIProvider):
     the context interceptor enabled.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, openai_client: Optional[AsyncOpenAI] = None, **kwargs):
         """Initialize the tracing model provider.
 
-        Accepts all the same arguments as OpenAIProvider.
+        Args:
+            openai_client: Optional custom AsyncOpenAI client. If provided, this client
+                will be used for all model calls. If not provided, OpenAIProvider
+                will create a default client.
+            **kwargs: All other arguments are passed to OpenAIProvider.
         """
-        super().__init__(*args, **kwargs)
+        # Pass openai_client to parent if provided
+        if openai_client is not None:
+            super().__init__(openai_client=openai_client, **kwargs)
+        else:
+            super().__init__(**kwargs)
 
         # Initialize tracer for all models
         agentex_client = create_async_agentex_client()
         self._tracer = AsyncTracer(agentex_client)
-        logger.info("[TemporalTracingModelProvider] Initialized with AgentEx tracer")
+        logger.info(f"[TemporalTracingModelProvider] Initialized with AgentEx tracer, custom_client={openai_client is not None}")
 
     @override
     def get_model(self, model_name: Optional[str]) -> Model:
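The provider-level hook mirrors the model-level one: a single injected client can back every Model the provider hands out. A minimal sketch, again assuming a hypothetical import path:

from openai import AsyncOpenAI

from temporal_tracing_model_provider import TemporalTracingModelProvider  # hypothetical import path

# One shared client for all models returned by the provider.
shared_client = AsyncOpenAI(max_retries=0)
provider = TemporalTracingModelProvider(openai_client=shared_client)
model = provider.get_model("gpt-4o")

Branching in __init__ rather than always forwarding openai_client keeps the no-client call identical to the old super().__init__(**kwargs) path, so OpenAIProvider's own client defaults stay in effect when nothing is injected.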