1+ """Temporal-aware tracing model provider.
2+
3+ This module provides model implementations that add AgentEx tracing to standard OpenAI models
4+ when running in Temporal workflows/activities. It uses context variables set by the Temporal
5+ context interceptor to access task_id, trace_id, and parent_span_id.
6+
7+ The key innovation is that these are thin wrappers around the standard OpenAI models,
8+ avoiding code duplication while adding tracing capabilities.
9+ """
10+
11+ from typing import Optional , Union , List , Any
12+ import logging
13+
14+ from agents import (
15+ Model ,
16+ ModelProvider ,
17+ ModelResponse ,
18+ ModelSettings ,
19+ ModelTracing ,
20+ Tool ,
21+ Handoff ,
22+ AgentOutputSchemaBase ,
23+ TResponseInputItem ,
24+ OpenAIProvider ,
25+ )
26+ from agents .models .openai_responses import OpenAIResponsesModel
27+ from agents .models .openai_chatcompletions import OpenAIChatCompletionsModel
28+ from openai import AsyncOpenAI
29+ from openai .types .responses import ResponsePromptParam
30+
31+ # Import AgentEx components
32+ from agentex .lib import adk
33+ from agentex .lib .adk .utils ._modules .client import create_async_agentex_client
34+ from agentex .lib .core .tracing .tracer import AsyncTracer
35+
36+ # Import context variables from the interceptor
37+ from agentex .lib .core .temporal .plugins .openai_agents .context_interceptor import (
38+ streaming_task_id ,
39+ streaming_trace_id ,
40+ streaming_parent_span_id ,
41+ )
42+
43+ logger = logging .getLogger ("agentex.temporal.tracing" )
44+
45+
46+ class TracingModelProvider (OpenAIProvider ):
47+ """Model provider that returns OpenAI models wrapped with AgentEx tracing.
48+
49+ This provider extends the standard OpenAIProvider to return models that add
50+ tracing spans around model calls when running in Temporal activities with
51+ the context interceptor enabled.
52+ """
53+
54+ def __init__ (self , * args , ** kwargs ):
55+ """Initialize the tracing model provider.
56+
57+ Accepts all the same arguments as OpenAIProvider.
58+ """
59+ super ().__init__ (* args , ** kwargs )
60+
61+ # Initialize tracer for all models
62+ agentex_client = create_async_agentex_client ()
63+ self ._tracer = AsyncTracer (agentex_client )
64+ logger .info ("[TracingModelProvider] Initialized with AgentEx tracer" )
65+
66+ def get_model (self , model_name : Optional [str ]) -> Model :
67+ """Get a model wrapped with tracing capabilities.
68+
69+ Args:
70+ model_name: The name of the model to use
71+
72+ Returns:
73+ A model instance wrapped with tracing
74+ """
75+ # Get the base model from the parent provider
76+ base_model = super ().get_model (model_name )
77+
78+ # Wrap with appropriate tracing wrapper based on model type
79+ if isinstance (base_model , OpenAIResponsesModel ):
80+ logger .info (f"[TracingModelProvider] Wrapping OpenAIResponsesModel '{ model_name } ' with tracing" )
81+ return TracingResponsesModel (base_model , self ._tracer )
82+ elif isinstance (base_model , OpenAIChatCompletionsModel ):
83+ logger .info (f"[TracingModelProvider] Wrapping OpenAIChatCompletionsModel '{ model_name } ' with tracing" )
84+ return TracingChatCompletionsModel (base_model , self ._tracer )
85+ else :
86+ logger .warning (f"[TracingModelProvider] Unknown model type, returning without tracing: { type (base_model )} " )
87+ return base_model
88+
89+
90+ class TracingResponsesModel (Model ):
91+ """Wrapper for OpenAIResponsesModel that adds AgentEx tracing.
92+
93+ This is a thin wrapper that adds tracing spans around the base model's
94+ get_response() method. It reads tracing context from ContextVars set by
95+ the Temporal context interceptor.
96+ """
97+
98+ def __init__ (self , base_model : OpenAIResponsesModel , tracer : AsyncTracer ):
99+ """Initialize the tracing wrapper.
100+
101+ Args:
102+ base_model: The OpenAI Responses model to wrap
103+ tracer: The AgentEx tracer to use
104+ """
105+ self ._base_model = base_model
106+ self ._tracer = tracer
107+ # Expose the model name for compatibility
108+ self .model = base_model .model
109+
110+ async def get_response (
111+ self ,
112+ system_instructions : Optional [str ],
113+ input : Union [str , List [TResponseInputItem ]],
114+ model_settings : ModelSettings ,
115+ tools : List [Tool ],
116+ output_schema : Optional [AgentOutputSchemaBase ],
117+ handoffs : List [Handoff ],
118+ tracing : ModelTracing ,
119+ previous_response_id : Optional [str ] = None ,
120+ conversation_id : Optional [str ] = None ,
121+ prompt : Optional [ResponsePromptParam ] = None ,
122+ ** kwargs ,
123+ ) -> ModelResponse :
124+ """Get a response from the model with optional tracing.
125+
126+ If tracing context is available from the interceptor, this wraps the
127+ model call in a tracing span. Otherwise, it passes through to the
128+ base model without tracing.
129+ """
130+ # Try to get tracing context from ContextVars
131+ task_id = streaming_task_id .get ()
132+ trace_id = streaming_trace_id .get ()
133+ parent_span_id = streaming_parent_span_id .get ()
134+
135+ # If we have tracing context, wrap with span
136+ if trace_id and parent_span_id :
137+ logger .debug (f"[TracingResponsesModel] Adding tracing span for task_id={ task_id } , trace_id={ trace_id } " )
138+
139+ trace = self ._tracer .trace (trace_id )
140+
141+ async with trace .span (
142+ parent_id = parent_span_id ,
143+ name = "model_get_response" ,
144+ input = {
145+ "model" : str (self .model ),
146+ "has_system_instructions" : system_instructions is not None ,
147+ "input_type" : type (input ).__name__ ,
148+ "tools_count" : len (tools ) if tools else 0 ,
149+ "handoffs_count" : len (handoffs ) if handoffs else 0 ,
150+ "has_output_schema" : output_schema is not None ,
151+ "model_settings" : {
152+ "temperature" : model_settings .temperature ,
153+ "max_tokens" : model_settings .max_tokens ,
154+ "reasoning" : model_settings .reasoning ,
155+ } if model_settings else None ,
156+ },
157+ ) as span :
158+ try :
159+ # Call the base model
160+ response = await self ._base_model .get_response (
161+ system_instructions = system_instructions ,
162+ input = input ,
163+ model_settings = model_settings ,
164+ tools = tools ,
165+ output_schema = output_schema ,
166+ handoffs = handoffs ,
167+ tracing = tracing ,
168+ previous_response_id = previous_response_id ,
169+ conversation_id = conversation_id ,
170+ prompt = prompt ,
171+ ** kwargs ,
172+ )
173+
174+ # Add response info to span output
175+ span .output = {
176+ "response_id" : getattr (response , "id" , None ),
177+ "model_used" : getattr (response , "model" , None ),
178+ "usage" : {
179+ "input_tokens" : response .usage .input_tokens if response .usage else None ,
180+ "output_tokens" : response .usage .output_tokens if response .usage else None ,
181+ "total_tokens" : response .usage .total_tokens if response .usage else None ,
182+ } if response .usage else None ,
183+ }
184+
185+ return response
186+
187+ except Exception as e :
188+ # Record error in span
189+ span .error = str (e )
190+ raise
191+ else :
192+ # No tracing context, just pass through
193+ logger .debug ("[TracingResponsesModel] No tracing context available, calling base model directly" )
194+ return await self ._base_model .get_response (
195+ system_instructions = system_instructions ,
196+ input = input ,
197+ model_settings = model_settings ,
198+ tools = tools ,
199+ output_schema = output_schema ,
200+ handoffs = handoffs ,
201+ tracing = tracing ,
202+ previous_response_id = previous_response_id ,
203+ conversation_id = conversation_id ,
204+ prompt = prompt ,
205+ ** kwargs ,
206+ )
207+
208+
209+ class TracingChatCompletionsModel (Model ):
210+ """Wrapper for OpenAIChatCompletionsModel that adds AgentEx tracing.
211+
212+ This is a thin wrapper that adds tracing spans around the base model's
213+ get_response() method. It reads tracing context from ContextVars set by
214+ the Temporal context interceptor.
215+ """
216+
217+ def __init__ (self , base_model : OpenAIChatCompletionsModel , tracer : AsyncTracer ):
218+ """Initialize the tracing wrapper.
219+
220+ Args:
221+ base_model: The OpenAI ChatCompletions model to wrap
222+ tracer: The AgentEx tracer to use
223+ """
224+ self ._base_model = base_model
225+ self ._tracer = tracer
226+ # Expose the model name for compatibility
227+ self .model = base_model .model
228+
229+ async def get_response (
230+ self ,
231+ system_instructions : Optional [str ],
232+ input : Union [str , List [TResponseInputItem ]],
233+ model_settings : ModelSettings ,
234+ tools : List [Tool ],
235+ output_schema : Optional [AgentOutputSchemaBase ],
236+ handoffs : List [Handoff ],
237+ tracing : ModelTracing ,
238+ ** kwargs ,
239+ ) -> ModelResponse :
240+ """Get a response from the model with optional tracing.
241+
242+ If tracing context is available from the interceptor, this wraps the
243+ model call in a tracing span. Otherwise, it passes through to the
244+ base model without tracing.
245+ """
246+ # Try to get tracing context from ContextVars
247+ task_id = streaming_task_id .get ()
248+ trace_id = streaming_trace_id .get ()
249+ parent_span_id = streaming_parent_span_id .get ()
250+
251+ # If we have tracing context, wrap with span
252+ if trace_id and parent_span_id :
253+ logger .debug (f"[TracingChatCompletionsModel] Adding tracing span for task_id={ task_id } , trace_id={ trace_id } " )
254+
255+ trace = self ._tracer .trace (trace_id )
256+
257+ async with trace .span (
258+ parent_id = parent_span_id ,
259+ name = "model_get_response" ,
260+ input = {
261+ "model" : str (self .model ),
262+ "has_system_instructions" : system_instructions is not None ,
263+ "input_type" : type (input ).__name__ ,
264+ "tools_count" : len (tools ) if tools else 0 ,
265+ "handoffs_count" : len (handoffs ) if handoffs else 0 ,
266+ "has_output_schema" : output_schema is not None ,
267+ "model_settings" : {
268+ "temperature" : model_settings .temperature ,
269+ "max_tokens" : model_settings .max_tokens ,
270+ } if model_settings else None ,
271+ },
272+ ) as span :
273+ try :
274+ # Call the base model
275+ response = await self ._base_model .get_response (
276+ system_instructions = system_instructions ,
277+ input = input ,
278+ model_settings = model_settings ,
279+ tools = tools ,
280+ output_schema = output_schema ,
281+ handoffs = handoffs ,
282+ tracing = tracing ,
283+ ** kwargs ,
284+ )
285+
286+ # Add response info to span output
287+ span .output = {
288+ "response_id" : getattr (response , "id" , None ),
289+ "model_used" : getattr (response , "model" , None ),
290+ "usage" : {
291+ "input_tokens" : response .usage .input_tokens if response .usage else None ,
292+ "output_tokens" : response .usage .output_tokens if response .usage else None ,
293+ "total_tokens" : response .usage .total_tokens if response .usage else None ,
294+ } if response .usage else None ,
295+ }
296+
297+ return response
298+
299+ except Exception as e :
300+ # Record error in span
301+ span .error = str (e )
302+ raise
303+ else :
304+ # No tracing context, just pass through
305+ logger .debug ("[TracingChatCompletionsModel] No tracing context available, calling base model directly" )
306+ return await self ._base_model .get_response (
307+ system_instructions = system_instructions ,
308+ input = input ,
309+ model_settings = model_settings ,
310+ tools = tools ,
311+ output_schema = output_schema ,
312+ handoffs = handoffs ,
313+ tracing = tracing ,
314+ ** kwargs ,
315+ )
0 commit comments