22
33from __future__ import annotations
44
5- from contextlib import contextmanager
5+ import contextvars
6+ import asyncio
7+ from contextlib import (
8+ AbstractContextManager ,
9+ contextmanager ,
10+ )
611from dataclasses import dataclass
12+ from types import TracebackType
713from typing import (
814 Any ,
915 Callable ,
16+ ContextManager ,
1017 Dict ,
1118 Iterator ,
1219 Mapping ,
@@ -386,15 +393,30 @@ async def execute_workflow(
386393 """Implementation of
387394 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`.
388395 """
389- with self ._top_level_workflow_context (success_is_complete = True ):
390- # Entrypoint of workflow should be `server` in OTel
391- self ._completed_span (
392- f"RunWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
393- kind = opentelemetry .trace .SpanKind .SERVER ,
394- )
395- return await super ().execute_workflow (input )
396+ with self ._top_level_workflow_context (success_is_complete = True ) as ctx :
397+ return await ctx .run (asyncio .create_task , self ._execute_workflow (input ))
398+
399+ async def _execute_workflow (
400+ self , input : temporalio .worker .ExecuteWorkflowInput
401+ ) -> Any :
402+ # Entrypoint of workflow should be `server` in OTel
403+ self ._completed_span (
404+ f"RunWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
405+ kind = opentelemetry .trace .SpanKind .SERVER ,
406+ )
407+ # with self._with_complete_span(success_is_complete=True):
408+ return await super ().execute_workflow (input )
396409
397410 async def handle_signal (self , input : temporalio .worker .HandleSignalInput ) -> None :
411+ """Implementation of
412+ :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
413+ """
414+ # Create a span in the current context for the signal and link any
415+ # header given
416+ with self ._top_level_workflow_context (success_is_complete = False ) as ctx :
417+ return await ctx .run (asyncio .create_task , self ._handle_signal (input ))
418+
419+ async def _handle_signal (self , input : temporalio .worker .HandleSignalInput ) -> None :
398420 """Implementation of
399421 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
400422 """
@@ -406,18 +428,22 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non
406428 link_context_carrier = self .payload_converter .from_payloads (
407429 [link_context_header ]
408430 )[0 ]
409- with self ._top_level_workflow_context (success_is_complete = False ):
410- self ._completed_span (
411- f"HandleSignal:{ input .signal } " ,
412- link_context_carrier = link_context_carrier ,
413- kind = opentelemetry .trace .SpanKind .SERVER ,
414- )
415- await super ().handle_signal (input )
431+ self ._completed_span (
432+ f"HandleSignal:{ input .signal } " ,
433+ link_context_carrier = link_context_carrier ,
434+ kind = opentelemetry .trace .SpanKind .SERVER ,
435+ )
436+ await super ().handle_signal (input )
416437
417438 async def handle_query (self , input : temporalio .worker .HandleQueryInput ) -> Any :
418439 """Implementation of
419440 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_query`.
420441 """
442+ # handle_query does not manage the contextvars.Context itself because
443+ # scheduling an asyncio task in a read only operation is not allowed.
444+ # The operation is synchronous which makes default contextvars.Context
445+ # safe.
446+
421447 # Only trace this if there is a header, and make that span the parent.
422448 # We do not put anything that happens in a query handler on the workflow
423449 # span.
@@ -448,11 +474,19 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
448474 )
449475 return await super ().handle_query (input )
450476 finally :
451- if attach_context == opentelemetry .context .get_current ():
452- opentelemetry .context .detach (token )
477+ opentelemetry .context .detach (token )
453478
454479 def handle_update_validator (
455480 self , input : temporalio .worker .HandleUpdateInput
481+ ) -> None :
482+ """Implementation of
483+ :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
484+ """
485+ with self ._top_level_workflow_context (success_is_complete = False ) as ctx :
486+ ctx .run (self ._handle_update_validator , input )
487+
488+ def _handle_update_validator (
489+ self , input : temporalio .worker .HandleUpdateInput
456490 ) -> None :
457491 """Implementation of
458492 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
@@ -463,7 +497,6 @@ def handle_update_validator(
463497 link_context_carrier = self .payload_converter .from_payloads (
464498 [link_context_header ]
465499 )[0 ]
466- with self ._top_level_workflow_context (success_is_complete = False ):
467500 self ._completed_span (
468501 f"ValidateUpdate:{ input .update } " ,
469502 link_context_carrier = link_context_carrier ,
@@ -473,6 +506,17 @@ def handle_update_validator(
473506
474507 async def handle_update_handler (
475508 self , input : temporalio .worker .HandleUpdateInput
509+ ) -> Any :
510+ """Implementation of
511+ :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
512+ """
513+ with self ._top_level_workflow_context (success_is_complete = False ) as ctx :
514+ return await ctx .run (
515+ asyncio .create_task , self ._handle_update_handler (input )
516+ )
517+
518+ async def _handle_update_handler (
519+ self , input : temporalio .worker .HandleUpdateInput
476520 ) -> Any :
477521 """Implementation of
478522 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
@@ -483,13 +527,12 @@ async def handle_update_handler(
483527 link_context_carrier = self .payload_converter .from_payloads (
484528 [link_context_header ]
485529 )[0 ]
486- with self ._top_level_workflow_context (success_is_complete = False ):
487- self ._completed_span (
488- f"HandleUpdate:{ input .update } " ,
489- link_context_carrier = link_context_carrier ,
490- kind = opentelemetry .trace .SpanKind .SERVER ,
491- )
492- return await super ().handle_update_handler (input )
530+ self ._completed_span (
531+ f"HandleUpdate:{ input .update } " ,
532+ link_context_carrier = link_context_carrier ,
533+ kind = opentelemetry .trace .SpanKind .SERVER ,
534+ )
535+ return await super ().handle_update_handler (input )
493536
494537 def _load_workflow_context_carrier (self ) -> Optional [_CarrierDict ]:
495538 if self ._workflow_context_carrier :
@@ -502,47 +545,13 @@ def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]:
502545 )[0 ]
503546 return self ._workflow_context_carrier
504547
505- @contextmanager
506548 def _top_level_workflow_context (
507549 self , * , success_is_complete : bool
508- ) -> Iterator [None ]:
509- # Load context only if there is a carrier, otherwise use empty context
510- context_carrier = self ._load_workflow_context_carrier ()
511- attach_context : opentelemetry .context .Context
512- if context_carrier :
513- attach_context = self .text_map_propagator .extract (context_carrier )
514- else :
515- attach_context = opentelemetry .context .Context ()
516- # We need to put this interceptor on the context too
517- attach_context = self ._set_on_context (attach_context )
518- # Need to know whether completed and whether there was a fail-workflow
519- # exception
520- success = False
521- exception : Optional [Exception ] = None
522- # Run under this context
523- token = opentelemetry .context .attach (attach_context )
524-
525- try :
526- yield None
527- success = True
528- except temporalio .exceptions .FailureError as err :
529- # We only record the failure errors since those are the only ones
530- # that lead to workflow completions
531- exception = err
532- raise
533- finally :
534- # Create a completed span before detaching context
535- if exception or (success and success_is_complete ):
536- self ._completed_span (
537- f"CompleteWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
538- exception = exception ,
539- kind = opentelemetry .trace .SpanKind .INTERNAL ,
540- )
541-
542- if attach_context == opentelemetry .context .get_current ():
543- opentelemetry .context .detach (token )
550+ ) -> ContextManager [contextvars .Context ]:
551+ return self ._TopLevelWorkflowContextManager (
552+ self , success_is_complete = success_is_complete
553+ )
544554
545- #
546555 def _context_to_headers (
547556 self , headers : Mapping [str , temporalio .api .common .v1 .Payload ]
548557 ) -> Mapping [str , temporalio .api .common .v1 .Payload ]:
@@ -616,6 +625,65 @@ def _set_on_context(
616625 ) -> opentelemetry .context .Context :
617626 return opentelemetry .context .set_value (_interceptor_context_key , self , context )
618627
628+ class _TopLevelWorkflowContextManager (AbstractContextManager ):
629+ def __init__ (
630+ self ,
631+ interceptor : TracingWorkflowInboundInterceptor ,
632+ * ,
633+ success_is_complete : bool ,
634+ ):
635+ self ._ctx = contextvars .copy_context ()
636+ self ._token = None
637+ self ._owner = interceptor
638+ self ._success_is_complete = success_is_complete
639+
640+ def __enter__ (self ):
641+ self ._ctx .run (self ._start )
642+ return self ._ctx
643+
644+ def __exit__ (
645+ self ,
646+ exc_type : Optional [type [BaseException ]],
647+ exc_value : Optional [BaseException ],
648+ traceback : Optional [TracebackType ], # noqa: F811
649+ ) -> bool | None :
650+ self ._ctx .run (self ._end , exc_type , exc_value , traceback )
651+
652+ def _start (self ):
653+ # Load context only if there is a carrier, otherwise use empty context
654+ context_carrier = self ._owner ._load_workflow_context_carrier ()
655+ attach_context : opentelemetry .context .Context
656+ if context_carrier :
657+ attach_context = self ._owner .text_map_propagator .extract (
658+ context_carrier
659+ )
660+ else :
661+ attach_context = opentelemetry .context .Context ()
662+ # We need to put this interceptor on the context too
663+ attach_context = self ._owner ._set_on_context (attach_context )
664+ self ._token = opentelemetry .context .attach (attach_context )
665+
666+ def _end (
667+ self ,
668+ exc_type : Optional [type [BaseException ]],
669+ exc_value : Optional [BaseException ],
670+ _traceback : Optional [TracebackType ],
671+ ):
672+ success = exc_type is None
673+ exception : Optional [temporalio .exceptions .FailureError ] = None
674+ if isinstance (exc_value , temporalio .exceptions .FailureError ):
675+ exception = exc_value
676+
677+ if (success and self ._success_is_complete ) or exception is not None :
678+ self ._owner ._completed_span (
679+ f"CompleteWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
680+ exception = exception ,
681+ kind = opentelemetry .trace .SpanKind .INTERNAL ,
682+ )
683+
684+ if self ._token :
685+ opentelemetry .context .detach (self ._token )
686+
619687
620688class _TracingWorkflowOutboundInterceptor (
621689 temporalio .worker .WorkflowOutboundInterceptor
0 commit comments