Skip to content

Commit 6ca781c

Browse files
committed
apply context management to openai_agents tracing
1 parent d1bf879 commit 6ca781c

File tree

1 file changed

+40
-28
lines changed

1 file changed

+40
-28
lines changed

temporalio/contrib/openai_agents/_trace_interceptor.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import asyncio
6+
import contextvars
57
import random
68
import uuid
79
from contextlib import contextmanager
@@ -400,48 +402,58 @@ async def signal_external_workflow(
400402
def start_activity(
401403
self, input: temporalio.worker.StartActivityInput
402404
) -> temporalio.workflow.ActivityHandle:
403-
trace = get_trace_provider().get_current_trace()
404-
span: Optional[Span] = None
405-
if trace:
406-
span = custom_span(
407-
name="temporal:startActivity", data={"activity": input.activity}
408-
)
409-
span.start(mark_as_current=True)
410-
411-
set_header_from_context(input, temporalio.workflow.payload_converter())
412-
handle = self.next.start_activity(input)
405+
ctx = contextvars.copy_context()
406+
span = ctx.run(
407+
self._create_span,
408+
name="temporal:startActivity",
409+
data={"activity": input.activity},
410+
input=input,
411+
)
412+
handle = ctx.run(self.next.start_activity, input)
413413
if span:
414-
handle.add_done_callback(lambda _: span.finish()) # type: ignore
414+
handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore
415415
return handle
416416

417417
async def start_child_workflow(
418418
self, input: temporalio.worker.StartChildWorkflowInput
419419
) -> temporalio.workflow.ChildWorkflowHandle:
420-
trace = get_trace_provider().get_current_trace()
421-
span: Optional[Span] = None
422-
if trace:
423-
span = custom_span(
424-
name="temporal:startChildWorkflow", data={"workflow": input.workflow}
425-
)
426-
span.start(mark_as_current=True)
427-
set_header_from_context(input, temporalio.workflow.payload_converter())
428-
handle = await self.next.start_child_workflow(input)
420+
ctx = contextvars.copy_context()
421+
span = ctx.run(
422+
self._create_span,
423+
name="temporal:startChildWorkflow",
424+
data={"workflow": input.workflow},
425+
input=input,
426+
)
427+
handle = await ctx.run(
428+
asyncio.create_task, self.next.start_child_workflow(input)
429+
)
429430
if span:
430-
handle.add_done_callback(lambda _: span.finish()) # type: ignore
431+
handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore
431432
return handle
432433

433434
def start_local_activity(
434435
self, input: temporalio.worker.StartLocalActivityInput
435436
) -> temporalio.workflow.ActivityHandle:
437+
ctx = contextvars.copy_context()
438+
span = ctx.run(
439+
self._create_span,
440+
name="temporal:startLocalActivity",
441+
data={"activity": input.activity},
442+
input=input,
443+
)
444+
handle = ctx.run(self.next.start_local_activity, input)
445+
if span:
446+
handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore
447+
return handle
448+
449+
def _create_span(
450+
self, name: str, data: dict[str, Any], input: _InputWithHeaders
451+
) -> Optional[Span]:
436452
trace = get_trace_provider().get_current_trace()
437453
span: Optional[Span] = None
438454
if trace:
439-
span = custom_span(
440-
name="temporal:startLocalActivity", data={"activity": input.activity}
441-
)
455+
span = custom_span(name=name, data=data)
442456
span.start(mark_as_current=True)
457+
443458
set_header_from_context(input, temporalio.workflow.payload_converter())
444-
handle = self.next.start_local_activity(input)
445-
if span:
446-
handle.add_done_callback(lambda _: span.finish()) # type: ignore
447-
return handle
459+
return span

0 commit comments

Comments
 (0)