Commit 7c33427

Create non-streaming model
1 parent 6c4a237 commit 7c33427

5 files changed: +357 -32 lines changed
src/agentex/lib/adk/providers/_modules/temporal_tracing.py

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
"""Temporal-aware tracing model provider.

This module provides model implementations that add AgentEx tracing to standard OpenAI models
when running in Temporal workflows/activities. It uses context variables set by the Temporal
context interceptor to access task_id, trace_id, and parent_span_id.

The key innovation is that these are thin wrappers around the standard OpenAI models,
avoiding code duplication while adding tracing capabilities.
"""

from typing import Optional, Union, List, Any
import logging

from agents import (
    Model,
    ModelProvider,
    ModelResponse,
    ModelSettings,
    ModelTracing,
    Tool,
    Handoff,
    AgentOutputSchemaBase,
    TResponseInputItem,
    OpenAIProvider,
)
from agents.models.openai_responses import OpenAIResponsesModel
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
from openai import AsyncOpenAI
from openai.types.responses import ResponsePromptParam

# Import AgentEx components
from agentex.lib import adk
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.tracing.tracer import AsyncTracer

# Import context variables from the interceptor
from agentex.lib.core.temporal.plugins.openai_agents.context_interceptor import (
    streaming_task_id,
    streaming_trace_id,
    streaming_parent_span_id,
)

logger = logging.getLogger("agentex.temporal.tracing")


class TracingModelProvider(OpenAIProvider):
    """Model provider that returns OpenAI models wrapped with AgentEx tracing.

    This provider extends the standard OpenAIProvider to return models that add
    tracing spans around model calls when running in Temporal activities with
    the context interceptor enabled.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the tracing model provider.

        Accepts all the same arguments as OpenAIProvider.
        """
        super().__init__(*args, **kwargs)

        # Initialize tracer for all models
        agentex_client = create_async_agentex_client()
        self._tracer = AsyncTracer(agentex_client)
        logger.info("[TracingModelProvider] Initialized with AgentEx tracer")

    def get_model(self, model_name: Optional[str]) -> Model:
        """Get a model wrapped with tracing capabilities.

        Args:
            model_name: The name of the model to use

        Returns:
            A model instance wrapped with tracing
        """
        # Get the base model from the parent provider
        base_model = super().get_model(model_name)

        # Wrap with the appropriate tracing wrapper based on model type
        if isinstance(base_model, OpenAIResponsesModel):
            logger.info(f"[TracingModelProvider] Wrapping OpenAIResponsesModel '{model_name}' with tracing")
            return TracingResponsesModel(base_model, self._tracer)
        elif isinstance(base_model, OpenAIChatCompletionsModel):
            logger.info(f"[TracingModelProvider] Wrapping OpenAIChatCompletionsModel '{model_name}' with tracing")
            return TracingChatCompletionsModel(base_model, self._tracer)
        else:
            logger.warning(f"[TracingModelProvider] Unknown model type, returning without tracing: {type(base_model)}")
            return base_model


class TracingResponsesModel(Model):
    """Wrapper for OpenAIResponsesModel that adds AgentEx tracing.

    This is a thin wrapper that adds tracing spans around the base model's
    get_response() method. It reads tracing context from ContextVars set by
    the Temporal context interceptor.
    """

    def __init__(self, base_model: OpenAIResponsesModel, tracer: AsyncTracer):
        """Initialize the tracing wrapper.

        Args:
            base_model: The OpenAI Responses model to wrap
            tracer: The AgentEx tracer to use
        """
        self._base_model = base_model
        self._tracer = tracer
        # Expose the model name for compatibility
        self.model = base_model.model

    async def get_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, List[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: List[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: List[Handoff],
        tracing: ModelTracing,
        previous_response_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        prompt: Optional[ResponsePromptParam] = None,
        **kwargs,
    ) -> ModelResponse:
        """Get a response from the model with optional tracing.

        If tracing context is available from the interceptor, this wraps the
        model call in a tracing span. Otherwise, it passes through to the
        base model without tracing.
        """
        # Try to get tracing context from ContextVars
        task_id = streaming_task_id.get()
        trace_id = streaming_trace_id.get()
        parent_span_id = streaming_parent_span_id.get()

        # If we have tracing context, wrap with a span
        if trace_id and parent_span_id:
            logger.debug(f"[TracingResponsesModel] Adding tracing span for task_id={task_id}, trace_id={trace_id}")

            trace = self._tracer.trace(trace_id)

            async with trace.span(
                parent_id=parent_span_id,
                name="model_get_response",
                input={
                    "model": str(self.model),
                    "has_system_instructions": system_instructions is not None,
                    "input_type": type(input).__name__,
                    "tools_count": len(tools) if tools else 0,
                    "handoffs_count": len(handoffs) if handoffs else 0,
                    "has_output_schema": output_schema is not None,
                    "model_settings": {
                        "temperature": model_settings.temperature,
                        "max_tokens": model_settings.max_tokens,
                        "reasoning": model_settings.reasoning,
                    } if model_settings else None,
                },
            ) as span:
                try:
                    # Call the base model
                    response = await self._base_model.get_response(
                        system_instructions=system_instructions,
                        input=input,
                        model_settings=model_settings,
                        tools=tools,
                        output_schema=output_schema,
                        handoffs=handoffs,
                        tracing=tracing,
                        previous_response_id=previous_response_id,
                        conversation_id=conversation_id,
                        prompt=prompt,
                        **kwargs,
                    )

                    # Add response info to span output
                    span.output = {
                        "response_id": getattr(response, "id", None),
                        "model_used": getattr(response, "model", None),
                        "usage": {
                            "input_tokens": response.usage.input_tokens if response.usage else None,
                            "output_tokens": response.usage.output_tokens if response.usage else None,
                            "total_tokens": response.usage.total_tokens if response.usage else None,
                        } if response.usage else None,
                    }

                    return response

                except Exception as e:
                    # Record error in span
                    span.error = str(e)
                    raise
        else:
            # No tracing context, just pass through
            logger.debug("[TracingResponsesModel] No tracing context available, calling base model directly")
            return await self._base_model.get_response(
                system_instructions=system_instructions,
                input=input,
                model_settings=model_settings,
                tools=tools,
                output_schema=output_schema,
                handoffs=handoffs,
                tracing=tracing,
                previous_response_id=previous_response_id,
                conversation_id=conversation_id,
                prompt=prompt,
                **kwargs,
            )


class TracingChatCompletionsModel(Model):
    """Wrapper for OpenAIChatCompletionsModel that adds AgentEx tracing.

    This is a thin wrapper that adds tracing spans around the base model's
    get_response() method. It reads tracing context from ContextVars set by
    the Temporal context interceptor.
    """

    def __init__(self, base_model: OpenAIChatCompletionsModel, tracer: AsyncTracer):
        """Initialize the tracing wrapper.

        Args:
            base_model: The OpenAI ChatCompletions model to wrap
            tracer: The AgentEx tracer to use
        """
        self._base_model = base_model
        self._tracer = tracer
        # Expose the model name for compatibility
        self.model = base_model.model

    async def get_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, List[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: List[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: List[Handoff],
        tracing: ModelTracing,
        **kwargs,
    ) -> ModelResponse:
        """Get a response from the model with optional tracing.

        If tracing context is available from the interceptor, this wraps the
        model call in a tracing span. Otherwise, it passes through to the
        base model without tracing.
        """
        # Try to get tracing context from ContextVars
        task_id = streaming_task_id.get()
        trace_id = streaming_trace_id.get()
        parent_span_id = streaming_parent_span_id.get()

        # If we have tracing context, wrap with a span
        if trace_id and parent_span_id:
            logger.debug(f"[TracingChatCompletionsModel] Adding tracing span for task_id={task_id}, trace_id={trace_id}")

            trace = self._tracer.trace(trace_id)

            async with trace.span(
                parent_id=parent_span_id,
                name="model_get_response",
                input={
                    "model": str(self.model),
                    "has_system_instructions": system_instructions is not None,
                    "input_type": type(input).__name__,
                    "tools_count": len(tools) if tools else 0,
                    "handoffs_count": len(handoffs) if handoffs else 0,
                    "has_output_schema": output_schema is not None,
                    "model_settings": {
                        "temperature": model_settings.temperature,
                        "max_tokens": model_settings.max_tokens,
                    } if model_settings else None,
                },
            ) as span:
                try:
                    # Call the base model
                    response = await self._base_model.get_response(
                        system_instructions=system_instructions,
                        input=input,
                        model_settings=model_settings,
                        tools=tools,
                        output_schema=output_schema,
                        handoffs=handoffs,
                        tracing=tracing,
                        **kwargs,
                    )

                    # Add response info to span output
                    span.output = {
                        "response_id": getattr(response, "id", None),
                        "model_used": getattr(response, "model", None),
                        "usage": {
                            "input_tokens": response.usage.input_tokens if response.usage else None,
                            "output_tokens": response.usage.output_tokens if response.usage else None,
                            "total_tokens": response.usage.total_tokens if response.usage else None,
                        } if response.usage else None,
                    }

                    return response

                except Exception as e:
                    # Record error in span
                    span.error = str(e)
                    raise
        else:
            # No tracing context, just pass through
            logger.debug("[TracingChatCompletionsModel] No tracing context available, calling base model directly")
            return await self._base_model.get_response(
                system_instructions=system_instructions,
                input=input,
                model_settings=model_settings,
                tools=tools,
                output_schema=output_schema,
                handoffs=handoffs,
                tracing=tracing,
                **kwargs,
            )
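For orientation, here is how the new pieces are meant to fit together. This is a minimal sketch that mirrors the docstring examples updated below; it assumes OpenAIAgentsPlugin accepts model_params and model_provider arguments (as in temporalio's contrib module), and the timeout value is illustrative rather than part of this diff.

from datetime import timedelta

from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters

from agentex.lib.core.temporal.plugins.openai_agents import (
    TracingModelProvider,
    ContextInterceptor,
)

# Model activities run through this plugin receive tracing-wrapped OpenAI
# models from the provider above.
plugin = OpenAIAgentsPlugin(
    model_params=ModelActivityParameters(
        start_to_close_timeout=timedelta(seconds=120),  # illustrative timeout
    ),
    model_provider=TracingModelProvider(),
)

# The interceptor sets the streaming_task_id / streaming_trace_id /
# streaming_parent_span_id ContextVars inside each activity; the wrappers
# read them in get_response().
interceptor = ContextInterceptor()
# Pass `plugin` to the worker/client plugin list and `interceptor` via the
# worker's interceptors configuration.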

src/agentex/lib/core/temporal/plugins/__init__.py

Lines changed: 7 additions & 4 deletions
@@ -11,7 +11,8 @@
 Example:
     >>> from agentex.lib.core.temporal.plugins.openai_agents import (
     ...     StreamingModelProvider,
-    ...     StreamingInterceptor,
+    ...     TracingModelProvider,
+    ...     ContextInterceptor,
     ... )
     >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters
     >>> from datetime import timedelta
@@ -28,14 +29,15 @@
     ... )
     >>>
     >>> # Register interceptor with worker
-    >>> interceptor = StreamingInterceptor()
+    >>> interceptor = ContextInterceptor()
     >>> # Add interceptor to worker configuration
 """

 from agentex.lib.core.temporal.plugins.openai_agents import (
     StreamingModel,
     StreamingModelProvider,
-    StreamingInterceptor,
+    TracingModelProvider,
+    ContextInterceptor,
     streaming_task_id,
     streaming_trace_id,
     streaming_parent_span_id,
@@ -46,7 +48,8 @@
 __all__ = [
     "StreamingModel",
     "StreamingModelProvider",
-    "StreamingInterceptor",
+    "TracingModelProvider",
+    "ContextInterceptor",
     "streaming_task_id",
     "streaming_trace_id",
     "streaming_parent_span_id",

src/agentex/lib/core/temporal/plugins/openai_agents/__init__.py

Lines changed: 10 additions & 5 deletions
@@ -13,7 +13,7 @@
     >>> from agentex.lib.core.temporal.plugins.openai_agents import (
     ...     StreamingModelProvider,
     ...     TemporalStreamingHooks,
-    ...     StreamingInterceptor,
+    ...     ContextInterceptor,
     ... )
     >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters
     >>> from datetime import timedelta
@@ -31,7 +31,7 @@
     ... )
     >>>
     >>> # 3. Register interceptor with worker
-    >>> interceptor = StreamingInterceptor()
+    >>> interceptor = ContextInterceptor()
     >>> # Add interceptor to worker configuration
     >>>
     >>> # 4. In workflow, store task_id in instance variable
@@ -55,8 +55,12 @@
     StreamingModel,
     StreamingModelProvider,
 )
-from agentex.lib.core.temporal.plugins.openai_agents.streaming_interceptor import (
-    StreamingInterceptor,
+# Import TracingModelProvider from the adk providers module
+from agentex.lib.adk.providers._modules.temporal_tracing import (
+    TracingModelProvider,
+)
+from agentex.lib.core.temporal.plugins.openai_agents.context_interceptor import (
+    ContextInterceptor,
     streaming_task_id,
     streaming_trace_id,
     streaming_parent_span_id,
@@ -71,7 +75,8 @@
 __all__ = [
     "StreamingModel",
     "StreamingModelProvider",
-    "StreamingInterceptor",
+    "TracingModelProvider",
+    "ContextInterceptor",
     "streaming_task_id",
     "streaming_trace_id",
     "streaming_parent_span_id",

0 commit comments