Skip to content

Commit 5fb122f

Browse files
committed
Add tracing test
1 parent 8476649 commit 5fb122f

File tree

2 files changed

+172
-8
lines changed

2 files changed

+172
-8
lines changed

temporalio/contrib/openai_agents/_trace_interceptor.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ async def start_workflow(
199199
**({"temporal:workflowId": input.id} if input.id else {}),
200200
}
201201
data = {"workflowId": input.id} if input.id else None
202-
span_name = f"temporal:startWorkflow"
202+
span_name = "temporal:startWorkflow"
203203
if get_trace_provider().get_current_trace() is None:
204204
with trace(
205205
span_name + ":" + input.workflow, metadata=metadata, group_id=input.id
@@ -218,7 +218,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A
218218
**({"temporal:workflowId": input.id} if input.id else {}),
219219
}
220220
data = {"workflowId": input.id, "query": input.query}
221-
span_name = f"temporal:queryWorkflow"
221+
span_name = "temporal:queryWorkflow"
222222
if get_trace_provider().get_current_trace() is None:
223223
with trace(span_name, metadata=metadata, group_id=input.id):
224224
with custom_span(name=span_name, data=data):
@@ -237,7 +237,7 @@ async def signal_workflow(
237237
**({"temporal:workflowId": input.id} if input.id else {}),
238238
}
239239
data = {"workflowId": input.id, "signal": input.signal}
240-
span_name = f"temporal:signalWorkflow"
240+
span_name = "temporal:signalWorkflow"
241241
if get_trace_provider().get_current_trace() is None:
242242
with trace(span_name, metadata=metadata, group_id=input.id):
243243
with custom_span(name=span_name, data=data):
@@ -337,7 +337,7 @@ async def signal_child_workflow(
337337
self, input: temporalio.worker.SignalChildWorkflowInput
338338
) -> None:
339339
with custom_span(
340-
name=f"temporal:signalChildWorkflow",
340+
name="temporal:signalChildWorkflow",
341341
data={"workflowId": input.child_workflow_id},
342342
):
343343
set_header_from_context(input, temporalio.workflow.payload_converter())
@@ -347,7 +347,7 @@ async def signal_external_workflow(
347347
self, input: temporalio.worker.SignalExternalWorkflowInput
348348
) -> None:
349349
with custom_span(
350-
name=f"temporal:signalExternalWorkflow",
350+
name="temporal:signalExternalWorkflow",
351351
data={"workflowId": input.workflow_id},
352352
):
353353
set_header_from_context(input, temporalio.workflow.payload_converter())
@@ -357,7 +357,7 @@ def start_activity(
357357
self, input: temporalio.worker.StartActivityInput
358358
) -> temporalio.workflow.ActivityHandle:
359359
span = custom_span(
360-
name=f"temporal:startActivity", data={"activity": input.activity}
360+
name="temporal:startActivity", data={"activity": input.activity}
361361
)
362362
span.start(mark_as_current=True)
363363
set_header_from_context(input, temporalio.workflow.payload_converter())
@@ -369,7 +369,7 @@ async def start_child_workflow(
369369
self, input: temporalio.worker.StartChildWorkflowInput
370370
) -> temporalio.workflow.ChildWorkflowHandle:
371371
span = custom_span(
372-
name=f"temporal:startChildWorkflow", data={"workflow": input.workflow}
372+
name="temporal:startChildWorkflow", data={"workflow": input.workflow}
373373
)
374374
span.start(mark_as_current=True)
375375
set_header_from_context(input, temporalio.workflow.payload_converter())
@@ -381,7 +381,7 @@ def start_local_activity(
381381
self, input: temporalio.worker.StartLocalActivityInput
382382
) -> temporalio.workflow.ActivityHandle:
383383
span = custom_span(
384-
name=f"temporal:startLocalActivity", data={"activity": input.activity}
384+
name="temporal:startLocalActivity", data={"activity": input.activity}
385385
)
386386
span.start(mark_as_current=True)
387387
set_header_from_context(input, temporalio.workflow.payload_converter())
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import datetime
2+
import uuid
3+
from datetime import timedelta
4+
from typing import Any, Optional, cast
5+
6+
from agents import Span, Trace, TracingProcessor
7+
from agents.tracing import get_trace_provider
8+
9+
from temporalio.client import Client
10+
from temporalio.contrib.openai_agents import (
11+
ModelActivity,
12+
OpenAIAgentsTracingInterceptor,
13+
TestModelProvider,
14+
set_open_ai_agent_temporal_overrides,
15+
)
16+
from temporalio.contrib.openai_agents._temporal_trace_provider import (
17+
TemporalTraceProvider,
18+
)
19+
from temporalio.contrib.pydantic import pydantic_data_converter
20+
from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel
21+
from tests.helpers import new_worker
22+
23+
24+
class MemoryTracingProcessor(TracingProcessor):
25+
# True for start events, false for end
26+
trace_events: list[tuple[Trace, bool]] = []
27+
span_events: list[tuple[Span, bool]] = []
28+
29+
def on_trace_start(self, trace: Trace) -> None:
30+
self.trace_events.append((trace, True))
31+
32+
def on_trace_end(self, trace: Trace) -> None:
33+
self.trace_events.append((trace, False))
34+
35+
def on_span_start(self, span: Span[Any]) -> None:
36+
self.span_events.append((span, True))
37+
38+
def on_span_end(self, span: Span[Any]) -> None:
39+
self.span_events.append((span, False))
40+
41+
def shutdown(self) -> None:
42+
pass
43+
44+
def force_flush(self) -> None:
45+
pass
46+
47+
48+
async def test_tracing(client: Client):
49+
new_config = client.config()
50+
new_config["data_converter"] = pydantic_data_converter
51+
client = Client(**new_config)
52+
53+
with set_open_ai_agent_temporal_overrides():
54+
provider = cast(TemporalTraceProvider, get_trace_provider())
55+
56+
processor = MemoryTracingProcessor()
57+
provider.set_processors([processor])
58+
59+
model_activity = ModelActivity(TestModelProvider(TestResearchModel()))
60+
async with new_worker(
61+
client,
62+
ResearchWorkflow,
63+
activities=[model_activity.invoke_model_activity],
64+
interceptors=[OpenAIAgentsTracingInterceptor()],
65+
) as worker:
66+
workflow_handle = await client.start_workflow(
67+
ResearchWorkflow.run,
68+
"Caribbean vacation spots in April, optimizing for surfing, hiking and water sports",
69+
id=f"research-workflow-{uuid.uuid4()}",
70+
task_queue=worker.task_queue,
71+
execution_timeout=timedelta(seconds=120),
72+
)
73+
result = await workflow_handle.result()
74+
75+
# There is one closed root trace
76+
assert len(processor.trace_events) == 2
77+
assert (
78+
processor.trace_events[0][0].trace_id
79+
== processor.trace_events[1][0].trace_id
80+
)
81+
assert processor.trace_events[0][1]
82+
assert not processor.trace_events[1][1]
83+
84+
def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None:
85+
assert a[0].trace_id == b[0].trace_id
86+
assert a[1]
87+
assert not b[1]
88+
89+
# Initial planner spans - There are only 3 because we don't make an actual model call
90+
paired_span(processor.span_events[0], processor.span_events[5])
91+
assert (
92+
processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent"
93+
)
94+
95+
paired_span(processor.span_events[1], processor.span_events[4])
96+
assert (
97+
processor.span_events[1][0].span_data.export().get("name")
98+
== "temporal:startActivity"
99+
)
100+
101+
paired_span(processor.span_events[2], processor.span_events[3])
102+
assert (
103+
processor.span_events[2][0].span_data.export().get("name")
104+
== "temporal:executeActivity"
105+
)
106+
107+
for span, start in processor.span_events[6:-6]:
108+
span_data = span.span_data.export()
109+
110+
# All spans should be closed
111+
if start:
112+
assert any(
113+
span.span_id == s.span_id and not s_start
114+
for (s, s_start) in processor.span_events
115+
)
116+
117+
def to_time(time: Optional[str]) -> datetime.datetime:
118+
assert time is not None
119+
return datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S.%f%z")
120+
121+
# Start activity is always parented to an agent
122+
if span_data.get("name") == "temporal:startActivity":
123+
parents = [
124+
s for (s, _) in processor.span_events if s.span_id == span.parent_id
125+
]
126+
assert (
127+
len(parents) == 2
128+
and parents[0].span_data.export()["type"] == "agent"
129+
)
130+
131+
assert to_time(span.started_at) >= to_time(parents[0].started_at)
132+
assert to_time(span.started_at) <= to_time(parents[1].ended_at)
133+
134+
# Execute is parented to start
135+
if span_data.get("name") == "temporal:executeActivity":
136+
parents = [
137+
s for (s, _) in processor.span_events if s.span_id == span.parent_id
138+
]
139+
assert (
140+
len(parents) == 2
141+
and parents[0].span_data.export()["name"]
142+
== "temporal:startActivity"
143+
)
144+
145+
assert to_time(span.started_at) >= to_time(parents[0].started_at)
146+
assert to_time(span.started_at) <= to_time(parents[1].ended_at)
147+
148+
# Final writer spans - There are only 3 because we don't make an actual model call
149+
paired_span(processor.span_events[-6], processor.span_events[-1])
150+
assert (
151+
processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent"
152+
)
153+
154+
paired_span(processor.span_events[-5], processor.span_events[-2])
155+
assert (
156+
processor.span_events[-5][0].span_data.export().get("name")
157+
== "temporal:startActivity"
158+
)
159+
160+
paired_span(processor.span_events[-4], processor.span_events[-3])
161+
assert (
162+
processor.span_events[-4][0].span_data.export().get("name")
163+
== "temporal:executeActivity"
164+
)

0 commit comments

Comments
 (0)