Skip to content

Commit fe8b0a5

Browse files
committed
revert manual context management explorations
1 parent 6deb5b1 commit fe8b0a5

File tree

3 files changed

+93
-198
lines changed

3 files changed

+93
-198
lines changed

temporalio/contrib/openai_agents/_trace_interceptor.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
from __future__ import annotations
44

5-
import asyncio
6-
import contextvars
75
import random
8-
import sys
96
import uuid
107
from contextlib import contextmanager
118
from typing import Any, Mapping, Optional, Protocol, Type
@@ -403,63 +400,48 @@ async def signal_external_workflow(
403400
def start_activity(
404401
self, input: temporalio.worker.StartActivityInput
405402
) -> temporalio.workflow.ActivityHandle:
406-
ctx = contextvars.copy_context()
407-
span = ctx.run(
408-
self._create_span,
409-
name="temporal:startActivity",
410-
data={"activity": input.activity},
411-
input=input,
412-
)
413-
handle = ctx.run(self.next.start_activity, input)
403+
trace = get_trace_provider().get_current_trace()
404+
span: Optional[Span] = None
405+
if trace:
406+
span = custom_span(
407+
name="temporal:startActivity", data={"activity": input.activity}
408+
)
409+
span.start(mark_as_current=True)
410+
411+
set_header_from_context(input, temporalio.workflow.payload_converter())
412+
handle = self.next.start_activity(input)
414413
if span:
415-
handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore
414+
handle.add_done_callback(lambda _: span.finish()) # type: ignore
416415
return handle
417416

418417
async def start_child_workflow(
419418
self, input: temporalio.worker.StartChildWorkflowInput
420419
) -> temporalio.workflow.ChildWorkflowHandle:
421-
ctx = contextvars.copy_context()
422-
span = ctx.run(
423-
self._create_span,
424-
name="temporal:startChildWorkflow",
425-
data={"workflow": input.workflow},
426-
input=input,
427-
)
428-
if sys.version_info >= (3, 11):
429-
handle: temporalio.workflow.ChildWorkflowHandle = await asyncio.create_task(
430-
self.next.start_child_workflow(input), context=ctx
431-
)
432-
else:
433-
handle: temporalio.workflow.ChildWorkflowHandle = await ctx.run(
434-
lambda: asyncio.create_task(self.next.start_child_workflow(input))
420+
trace = get_trace_provider().get_current_trace()
421+
span: Optional[Span] = None
422+
if trace:
423+
span = custom_span(
424+
name="temporal:startChildWorkflow", data={"workflow": input.workflow}
435425
)
426+
span.start(mark_as_current=True)
427+
set_header_from_context(input, temporalio.workflow.payload_converter())
428+
handle = await self.next.start_child_workflow(input)
436429
if span:
437-
handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore
430+
handle.add_done_callback(lambda _: span.finish()) # type: ignore
438431
return handle
439432

440433
def start_local_activity(
441434
self, input: temporalio.worker.StartLocalActivityInput
442435
) -> temporalio.workflow.ActivityHandle:
443-
ctx = contextvars.copy_context()
444-
span = ctx.run(
445-
self._create_span,
446-
name="temporal:startLocalActivity",
447-
data={"activity": input.activity},
448-
input=input,
449-
)
450-
handle = ctx.run(self.next.start_local_activity, input)
451-
if span:
452-
handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore
453-
return handle
454-
455-
def _create_span(
456-
self, name: str, data: dict[str, Any], input: _InputWithHeaders
457-
) -> Optional[Span]:
458436
trace = get_trace_provider().get_current_trace()
459437
span: Optional[Span] = None
460438
if trace:
461-
span = custom_span(name=name, data=data)
439+
span = custom_span(
440+
name="temporal:startLocalActivity", data={"activity": input.activity}
441+
)
462442
span.start(mark_as_current=True)
463-
464443
set_header_from_context(input, temporalio.workflow.payload_converter())
465-
return span
444+
handle = self.next.start_local_activity(input)
445+
if span:
446+
handle.add_done_callback(lambda _: span.finish()) # type: ignore
447+
return handle

temporalio/contrib/opentelemetry.py

Lines changed: 63 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,11 @@
22

33
from __future__ import annotations
44

5-
import asyncio
6-
import contextvars
7-
import sys
8-
from contextlib import (
9-
AbstractContextManager,
10-
contextmanager,
11-
)
5+
from contextlib import contextmanager
126
from dataclasses import dataclass
13-
from types import TracebackType
147
from typing import (
158
Any,
169
Callable,
17-
ContextManager,
18-
Coroutine,
1910
Dict,
2011
Iterator,
2112
Mapping,
@@ -419,38 +410,15 @@ async def execute_workflow(
419410
"""Implementation of
420411
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`.
421412
"""
422-
with self._top_level_workflow_context(success_is_complete=True) as ctx:
423-
if sys.version_info >= (3, 11):
424-
return await asyncio.create_task(
425-
self._execute_workflow(input), context=ctx
426-
)
427-
else:
428-
return await ctx.run(asyncio.create_task, self._execute_workflow(input))
429-
430-
async def _execute_workflow(
431-
self, input: temporalio.worker.ExecuteWorkflowInput
432-
) -> Any:
433-
# Entrypoint of workflow should be `server` in OTel
434-
self._completed_span(
435-
f"RunWorkflow:{temporalio.workflow.info().workflow_type}",
436-
kind=opentelemetry.trace.SpanKind.SERVER,
437-
)
438-
# with self._with_complete_span(success_is_complete=True):
439-
return await super().execute_workflow(input)
413+
with self._top_level_workflow_context(success_is_complete=True):
414+
# Entrypoint of workflow should be `server` in OTel
415+
self._completed_span(
416+
f"RunWorkflow:{temporalio.workflow.info().workflow_type}",
417+
kind=opentelemetry.trace.SpanKind.SERVER,
418+
)
419+
return await super().execute_workflow(input)
440420

441421
async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
442-
"""Implementation of
443-
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
444-
"""
445-
# Create a span in the current context for the signal and link any
446-
# header given
447-
with self._top_level_workflow_context(success_is_complete=False) as ctx:
448-
if sys.version_info >= (3, 11):
449-
await asyncio.create_task(self._handle_signal(input), context=ctx)
450-
else:
451-
await ctx.run(lambda: asyncio.create_task(self._handle_signal(input)))
452-
453-
async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
454422
"""Implementation of
455423
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
456424
"""
@@ -462,22 +430,18 @@ async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> No
462430
link_context_carrier = self.payload_converter.from_payloads(
463431
[link_context_header]
464432
)[0]
465-
self._completed_span(
466-
f"HandleSignal:{input.signal}",
467-
link_context_carrier=link_context_carrier,
468-
kind=opentelemetry.trace.SpanKind.SERVER,
469-
)
470-
await super().handle_signal(input)
433+
with self._top_level_workflow_context(success_is_complete=False):
434+
self._completed_span(
435+
f"HandleSignal:{input.signal}",
436+
link_context_carrier=link_context_carrier,
437+
kind=opentelemetry.trace.SpanKind.SERVER,
438+
)
439+
await super().handle_signal(input)
471440

472441
async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
473442
"""Implementation of
474443
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_query`.
475444
"""
476-
# handle_query does not manage the contextvars.Context itself because
477-
# scheduling an asyncio task in a read only operation is not allowed.
478-
# The operation is synchronous which makes default contextvars.Context
479-
# safe.
480-
481445
# Only trace this if there is a header, and make that span the parent.
482446
# We do not put anything that happens in a query handler on the workflow
483447
# span.
@@ -508,19 +472,11 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
508472
)
509473
return await super().handle_query(input)
510474
finally:
511-
opentelemetry.context.detach(token)
475+
if attach_context == opentelemetry.context.get_current():
476+
opentelemetry.context.detach(token)
512477

513478
def handle_update_validator(
514479
self, input: temporalio.worker.HandleUpdateInput
515-
) -> None:
516-
"""Implementation of
517-
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
518-
"""
519-
with self._top_level_workflow_context(success_is_complete=False) as ctx:
520-
ctx.run(self._handle_update_validator, input)
521-
522-
def _handle_update_validator(
523-
self, input: temporalio.worker.HandleUpdateInput
524480
) -> None:
525481
"""Implementation of
526482
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
@@ -531,6 +487,7 @@ def _handle_update_validator(
531487
link_context_carrier = self.payload_converter.from_payloads(
532488
[link_context_header]
533489
)[0]
490+
with self._top_level_workflow_context(success_is_complete=False):
534491
self._completed_span(
535492
f"ValidateUpdate:{input.update}",
536493
link_context_carrier=link_context_carrier,
@@ -540,22 +497,6 @@ def _handle_update_validator(
540497

541498
async def handle_update_handler(
542499
self, input: temporalio.worker.HandleUpdateInput
543-
) -> Any:
544-
"""Implementation of
545-
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
546-
"""
547-
with self._top_level_workflow_context(success_is_complete=False) as ctx:
548-
if sys.version_info >= (3, 11):
549-
return await asyncio.create_task(
550-
self._handle_update_handler(input), context=ctx
551-
)
552-
else:
553-
return await ctx.run(
554-
asyncio.create_task, self._handle_update_handler(input)
555-
)
556-
557-
async def _handle_update_handler(
558-
self, input: temporalio.worker.HandleUpdateInput
559500
) -> Any:
560501
"""Implementation of
561502
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
@@ -566,12 +507,13 @@ async def _handle_update_handler(
566507
link_context_carrier = self.payload_converter.from_payloads(
567508
[link_context_header]
568509
)[0]
569-
self._completed_span(
570-
f"HandleUpdate:{input.update}",
571-
link_context_carrier=link_context_carrier,
572-
kind=opentelemetry.trace.SpanKind.SERVER,
573-
)
574-
return await super().handle_update_handler(input)
510+
with self._top_level_workflow_context(success_is_complete=False):
511+
self._completed_span(
512+
f"HandleUpdate:{input.update}",
513+
link_context_carrier=link_context_carrier,
514+
kind=opentelemetry.trace.SpanKind.SERVER,
515+
)
516+
return await super().handle_update_handler(input)
575517

576518
def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]:
577519
if self._workflow_context_carrier:
@@ -584,13 +526,47 @@ def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]:
584526
)[0]
585527
return self._workflow_context_carrier
586528

529+
@contextmanager
587530
def _top_level_workflow_context(
588531
self, *, success_is_complete: bool
589-
) -> ContextManager[contextvars.Context]:
590-
return self._TopLevelWorkflowContextManager(
591-
self, success_is_complete=success_is_complete
592-
)
532+
) -> Iterator[None]:
533+
# Load context only if there is a carrier, otherwise use empty context
534+
context_carrier = self._load_workflow_context_carrier()
535+
attach_context: opentelemetry.context.Context
536+
if context_carrier:
537+
attach_context = self.text_map_propagator.extract(context_carrier)
538+
else:
539+
attach_context = opentelemetry.context.Context()
540+
# We need to put this interceptor on the context too
541+
attach_context = self._set_on_context(attach_context)
542+
# Need to know whether completed and whether there was a fail-workflow
543+
# exception
544+
success = False
545+
exception: Optional[Exception] = None
546+
# Run under this context
547+
token = opentelemetry.context.attach(attach_context)
548+
549+
try:
550+
yield None
551+
success = True
552+
except temporalio.exceptions.FailureError as err:
553+
# We only record the failure errors since those are the only ones
554+
# that lead to workflow completions
555+
exception = err
556+
raise
557+
finally:
558+
# Create a completed span before detaching context
559+
if exception or (success and success_is_complete):
560+
self._completed_span(
561+
f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}",
562+
exception=exception,
563+
kind=opentelemetry.trace.SpanKind.INTERNAL,
564+
)
565+
566+
if attach_context == opentelemetry.context.get_current():
567+
opentelemetry.context.detach(token)
593568

569+
#
594570
def _context_to_headers(
595571
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
596572
) -> Mapping[str, temporalio.api.common.v1.Payload]:
@@ -664,65 +640,6 @@ def _set_on_context(
664640
) -> opentelemetry.context.Context:
665641
return opentelemetry.context.set_value(_interceptor_context_key, self, context)
666642

667-
class _TopLevelWorkflowContextManager(AbstractContextManager):
668-
def __init__(
669-
self,
670-
interceptor: TracingWorkflowInboundInterceptor,
671-
*,
672-
success_is_complete: bool,
673-
):
674-
self._ctx = contextvars.copy_context()
675-
self._token: Optional[contextvars.Token] = None
676-
self._owner = interceptor
677-
self._success_is_complete = success_is_complete
678-
679-
def __enter__(self):
680-
self._ctx.run(self._start)
681-
return self._ctx
682-
683-
def __exit__(
684-
self,
685-
exc_type: Optional[type[BaseException]],
686-
exc_value: Optional[BaseException],
687-
traceback: Optional[TracebackType], # noqa: F811
688-
):
689-
self._ctx.run(self._end, exc_type, exc_value, traceback)
690-
691-
def _start(self):
692-
# Load context only if there is a carrier, otherwise use empty context
693-
context_carrier = self._owner._load_workflow_context_carrier()
694-
attach_context: opentelemetry.context.Context
695-
if context_carrier:
696-
attach_context = self._owner.text_map_propagator.extract(
697-
context_carrier
698-
)
699-
else:
700-
attach_context = opentelemetry.context.Context()
701-
# We need to put this interceptor on the context too
702-
attach_context = self._owner._set_on_context(attach_context)
703-
self._token = opentelemetry.context.attach(attach_context)
704-
705-
def _end(
706-
self,
707-
exc_type: Optional[type[BaseException]],
708-
exc_value: Optional[BaseException],
709-
_traceback: Optional[TracebackType],
710-
):
711-
success = exc_type is None
712-
exception: Optional[temporalio.exceptions.FailureError] = None
713-
if isinstance(exc_value, temporalio.exceptions.FailureError):
714-
exception = exc_value
715-
716-
if (success and self._success_is_complete) or exception is not None:
717-
self._owner._completed_span(
718-
f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}",
719-
exception=exception,
720-
kind=opentelemetry.trace.SpanKind.INTERNAL,
721-
)
722-
723-
if self._token:
724-
opentelemetry.context.detach(self._token)
725-
726643

727644
class _TracingWorkflowOutboundInterceptor(
728645
temporalio.worker.WorkflowOutboundInterceptor

0 commit comments

Comments
 (0)