Skip to content

Commit d1bf879

Browse files
committed
Move to a class based context manager to manually manage context
1 parent 5abc5a4 commit d1bf879

File tree

2 files changed

+134
-63
lines changed

2 files changed

+134
-63
lines changed

temporalio/contrib/opentelemetry.py

Lines changed: 131 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@
22

33
from __future__ import annotations
44

5-
from contextlib import contextmanager
5+
import contextvars
6+
import asyncio
7+
from contextlib import (
8+
AbstractContextManager,
9+
contextmanager,
10+
)
611
from dataclasses import dataclass
12+
from types import TracebackType
713
from typing import (
814
Any,
915
Callable,
16+
ContextManager,
1017
Dict,
1118
Iterator,
1219
Mapping,
@@ -386,15 +393,30 @@ async def execute_workflow(
386393
"""Implementation of
387394
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`.
388395
"""
389-
with self._top_level_workflow_context(success_is_complete=True):
390-
# Entrypoint of workflow should be `server` in OTel
391-
self._completed_span(
392-
f"RunWorkflow:{temporalio.workflow.info().workflow_type}",
393-
kind=opentelemetry.trace.SpanKind.SERVER,
394-
)
395-
return await super().execute_workflow(input)
396+
with self._top_level_workflow_context(success_is_complete=True) as ctx:
397+
return await ctx.run(asyncio.create_task, self._execute_workflow(input))
398+
399+
async def _execute_workflow(
400+
self, input: temporalio.worker.ExecuteWorkflowInput
401+
) -> Any:
402+
# Entrypoint of workflow should be `server` in OTel
403+
self._completed_span(
404+
f"RunWorkflow:{temporalio.workflow.info().workflow_type}",
405+
kind=opentelemetry.trace.SpanKind.SERVER,
406+
)
407+
# with self._with_complete_span(success_is_complete=True):
408+
return await super().execute_workflow(input)
396409

397410
async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
411+
"""Implementation of
412+
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
413+
"""
414+
# Create a span in the current context for the signal and link any
415+
# header given
416+
with self._top_level_workflow_context(success_is_complete=False) as ctx:
417+
return await ctx.run(asyncio.create_task, self._handle_signal(input))
418+
419+
async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
398420
"""Implementation of
399421
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
400422
"""
@@ -406,18 +428,22 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non
406428
link_context_carrier = self.payload_converter.from_payloads(
407429
[link_context_header]
408430
)[0]
409-
with self._top_level_workflow_context(success_is_complete=False):
410-
self._completed_span(
411-
f"HandleSignal:{input.signal}",
412-
link_context_carrier=link_context_carrier,
413-
kind=opentelemetry.trace.SpanKind.SERVER,
414-
)
415-
await super().handle_signal(input)
431+
self._completed_span(
432+
f"HandleSignal:{input.signal}",
433+
link_context_carrier=link_context_carrier,
434+
kind=opentelemetry.trace.SpanKind.SERVER,
435+
)
436+
await super().handle_signal(input)
416437

417438
async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
418439
"""Implementation of
419440
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_query`.
420441
"""
442+
# handle_query does not manage the contextvars.Context itself because
443+
# scheduling an asyncio task in a read only operation is not allowed.
444+
# The operation is synchronous which makes default contextvars.Context
445+
# safe.
446+
421447
# Only trace this if there is a header, and make that span the parent.
422448
# We do not put anything that happens in a query handler on the workflow
423449
# span.
@@ -448,11 +474,19 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
448474
)
449475
return await super().handle_query(input)
450476
finally:
451-
if attach_context == opentelemetry.context.get_current():
452-
opentelemetry.context.detach(token)
477+
opentelemetry.context.detach(token)
453478

454479
def handle_update_validator(
455480
self, input: temporalio.worker.HandleUpdateInput
481+
) -> None:
482+
"""Implementation of
483+
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
484+
"""
485+
with self._top_level_workflow_context(success_is_complete=False) as ctx:
486+
ctx.run(self._handle_update_validator, input)
487+
488+
def _handle_update_validator(
489+
self, input: temporalio.worker.HandleUpdateInput
456490
) -> None:
457491
"""Implementation of
458492
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
@@ -463,7 +497,6 @@ def handle_update_validator(
463497
link_context_carrier = self.payload_converter.from_payloads(
464498
[link_context_header]
465499
)[0]
466-
with self._top_level_workflow_context(success_is_complete=False):
467500
self._completed_span(
468501
f"ValidateUpdate:{input.update}",
469502
link_context_carrier=link_context_carrier,
@@ -473,6 +506,17 @@ def handle_update_validator(
473506

474507
async def handle_update_handler(
475508
self, input: temporalio.worker.HandleUpdateInput
509+
) -> Any:
510+
"""Implementation of
511+
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
512+
"""
513+
with self._top_level_workflow_context(success_is_complete=False) as ctx:
514+
return await ctx.run(
515+
asyncio.create_task, self._handle_update_handler(input)
516+
)
517+
518+
async def _handle_update_handler(
519+
self, input: temporalio.worker.HandleUpdateInput
476520
) -> Any:
477521
"""Implementation of
478522
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
@@ -483,13 +527,12 @@ async def handle_update_handler(
483527
link_context_carrier = self.payload_converter.from_payloads(
484528
[link_context_header]
485529
)[0]
486-
with self._top_level_workflow_context(success_is_complete=False):
487-
self._completed_span(
488-
f"HandleUpdate:{input.update}",
489-
link_context_carrier=link_context_carrier,
490-
kind=opentelemetry.trace.SpanKind.SERVER,
491-
)
492-
return await super().handle_update_handler(input)
530+
self._completed_span(
531+
f"HandleUpdate:{input.update}",
532+
link_context_carrier=link_context_carrier,
533+
kind=opentelemetry.trace.SpanKind.SERVER,
534+
)
535+
return await super().handle_update_handler(input)
493536

494537
def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]:
495538
if self._workflow_context_carrier:
@@ -502,47 +545,13 @@ def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]:
502545
)[0]
503546
return self._workflow_context_carrier
504547

505-
@contextmanager
506548
def _top_level_workflow_context(
507549
self, *, success_is_complete: bool
508-
) -> Iterator[None]:
509-
# Load context only if there is a carrier, otherwise use empty context
510-
context_carrier = self._load_workflow_context_carrier()
511-
attach_context: opentelemetry.context.Context
512-
if context_carrier:
513-
attach_context = self.text_map_propagator.extract(context_carrier)
514-
else:
515-
attach_context = opentelemetry.context.Context()
516-
# We need to put this interceptor on the context too
517-
attach_context = self._set_on_context(attach_context)
518-
# Need to know whether completed and whether there was a fail-workflow
519-
# exception
520-
success = False
521-
exception: Optional[Exception] = None
522-
# Run under this context
523-
token = opentelemetry.context.attach(attach_context)
524-
525-
try:
526-
yield None
527-
success = True
528-
except temporalio.exceptions.FailureError as err:
529-
# We only record the failure errors since those are the only ones
530-
# that lead to workflow completions
531-
exception = err
532-
raise
533-
finally:
534-
# Create a completed span before detaching context
535-
if exception or (success and success_is_complete):
536-
self._completed_span(
537-
f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}",
538-
exception=exception,
539-
kind=opentelemetry.trace.SpanKind.INTERNAL,
540-
)
541-
542-
if attach_context == opentelemetry.context.get_current():
543-
opentelemetry.context.detach(token)
550+
) -> ContextManager[contextvars.Context]:
551+
return self._TopLevelWorkflowContextManager(
552+
self, success_is_complete=success_is_complete
553+
)
544554

545-
#
546555
def _context_to_headers(
547556
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
548557
) -> Mapping[str, temporalio.api.common.v1.Payload]:
@@ -616,6 +625,65 @@ def _set_on_context(
616625
) -> opentelemetry.context.Context:
617626
return opentelemetry.context.set_value(_interceptor_context_key, self, context)
618627

628+
class _TopLevelWorkflowContextManager(AbstractContextManager):
629+
def __init__(
630+
self,
631+
interceptor: TracingWorkflowInboundInterceptor,
632+
*,
633+
success_is_complete: bool,
634+
):
635+
self._ctx = contextvars.copy_context()
636+
self._token = None
637+
self._owner = interceptor
638+
self._success_is_complete = success_is_complete
639+
640+
def __enter__(self):
641+
self._ctx.run(self._start)
642+
return self._ctx
643+
644+
def __exit__(
645+
self,
646+
exc_type: Optional[type[BaseException]],
647+
exc_value: Optional[BaseException],
648+
traceback: Optional[TracebackType], # noqa: F811
649+
) -> bool | None:
650+
self._ctx.run(self._end, exc_type, exc_value, traceback)
651+
652+
def _start(self):
653+
# Load context only if there is a carrier, otherwise use empty context
654+
context_carrier = self._owner._load_workflow_context_carrier()
655+
attach_context: opentelemetry.context.Context
656+
if context_carrier:
657+
attach_context = self._owner.text_map_propagator.extract(
658+
context_carrier
659+
)
660+
else:
661+
attach_context = opentelemetry.context.Context()
662+
# We need to put this interceptor on the context too
663+
attach_context = self._owner._set_on_context(attach_context)
664+
self._token = opentelemetry.context.attach(attach_context)
665+
666+
def _end(
667+
self,
668+
exc_type: Optional[type[BaseException]],
669+
exc_value: Optional[BaseException],
670+
_traceback: Optional[TracebackType],
671+
):
672+
success = exc_type is None
673+
exception: Optional[temporalio.exceptions.FailureError] = None
674+
if isinstance(exc_value, temporalio.exceptions.FailureError):
675+
exception = exc_value
676+
677+
if (success and self._success_is_complete) or exception is not None:
678+
self._owner._completed_span(
679+
f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}",
680+
exception=exception,
681+
kind=opentelemetry.trace.SpanKind.INTERNAL,
682+
)
683+
684+
if self._token:
685+
opentelemetry.context.detach(self._token)
686+
619687

620688
class _TracingWorkflowOutboundInterceptor(
621689
temporalio.worker.WorkflowOutboundInterceptor

temporalio/worker/_workflow.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ async def _handle_activation(
337337
raise deadlock_exc from None
338338

339339
except Exception as err:
340+
if isinstance(err, GeneratorExit):
341+
print("###############generator exit#############")
342+
340343
if isinstance(err, _DeadlockError):
341344
err.swap_traceback()
342345

0 commit comments

Comments
 (0)