22
33from __future__ import annotations
44
5- import asyncio
6- import contextvars
7- import sys
8- from contextlib import (
9- AbstractContextManager ,
10- contextmanager ,
11- )
5+ from contextlib import contextmanager
126from dataclasses import dataclass
13- from types import TracebackType
147from typing import (
158 Any ,
169 Callable ,
17- ContextManager ,
18- Coroutine ,
1910 Dict ,
2011 Iterator ,
2112 Mapping ,
@@ -419,38 +410,15 @@ async def execute_workflow(
419410 """Implementation of
420411 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`.
421412 """
422- with self ._top_level_workflow_context (success_is_complete = True ) as ctx :
423- if sys .version_info >= (3 , 11 ):
424- return await asyncio .create_task (
425- self ._execute_workflow (input ), context = ctx
426- )
427- else :
428- return await ctx .run (asyncio .create_task , self ._execute_workflow (input ))
429-
430- async def _execute_workflow (
431- self , input : temporalio .worker .ExecuteWorkflowInput
432- ) -> Any :
433- # Entrypoint of workflow should be `server` in OTel
434- self ._completed_span (
435- f"RunWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
436- kind = opentelemetry .trace .SpanKind .SERVER ,
437- )
438- # with self._with_complete_span(success_is_complete=True):
439- return await super ().execute_workflow (input )
413+ with self ._top_level_workflow_context (success_is_complete = True ):
414+ # Entrypoint of workflow should be `server` in OTel
415+ self ._completed_span (
416+ f"RunWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
417+ kind = opentelemetry .trace .SpanKind .SERVER ,
418+ )
419+ return await super ().execute_workflow (input )
440420
441421 async def handle_signal (self , input : temporalio .worker .HandleSignalInput ) -> None :
442- """Implementation of
443- :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
444- """
445- # Create a span in the current context for the signal and link any
446- # header given
447- with self ._top_level_workflow_context (success_is_complete = False ) as ctx :
448- if sys .version_info >= (3 , 11 ):
449- await asyncio .create_task (self ._handle_signal (input ), context = ctx )
450- else :
451- await ctx .run (lambda : asyncio .create_task (self ._handle_signal (input )))
452-
453- async def _handle_signal (self , input : temporalio .worker .HandleSignalInput ) -> None :
454422 """Implementation of
455423 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`.
456424 """
@@ -462,22 +430,18 @@ async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> No
462430 link_context_carrier = self .payload_converter .from_payloads (
463431 [link_context_header ]
464432 )[0 ]
465- self ._completed_span (
466- f"HandleSignal:{ input .signal } " ,
467- link_context_carrier = link_context_carrier ,
468- kind = opentelemetry .trace .SpanKind .SERVER ,
469- )
470- await super ().handle_signal (input )
433+ with self ._top_level_workflow_context (success_is_complete = False ):
434+ self ._completed_span (
435+ f"HandleSignal:{ input .signal } " ,
436+ link_context_carrier = link_context_carrier ,
437+ kind = opentelemetry .trace .SpanKind .SERVER ,
438+ )
439+ await super ().handle_signal (input )
471440
472441 async def handle_query (self , input : temporalio .worker .HandleQueryInput ) -> Any :
473442 """Implementation of
474443 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_query`.
475444 """
476- # handle_query does not manage the contextvars.Context itself because
477- # scheduling an asyncio task in a read only operation is not allowed.
478- # The operation is synchronous which makes default contextvars.Context
479- # safe.
480-
481445 # Only trace this if there is a header, and make that span the parent.
482446 # We do not put anything that happens in a query handler on the workflow
483447 # span.
@@ -508,19 +472,11 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
508472 )
509473 return await super ().handle_query (input )
510474 finally :
511- opentelemetry .context .detach (token )
475+ if attach_context == opentelemetry .context .get_current ():
476+ opentelemetry .context .detach (token )
512477
513478 def handle_update_validator (
514479 self , input : temporalio .worker .HandleUpdateInput
515- ) -> None :
516- """Implementation of
517- :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
518- """
519- with self ._top_level_workflow_context (success_is_complete = False ) as ctx :
520- ctx .run (self ._handle_update_validator , input )
521-
522- def _handle_update_validator (
523- self , input : temporalio .worker .HandleUpdateInput
524480 ) -> None :
525481 """Implementation of
526482 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`.
@@ -531,6 +487,7 @@ def _handle_update_validator(
531487 link_context_carrier = self .payload_converter .from_payloads (
532488 [link_context_header ]
533489 )[0 ]
490+ with self ._top_level_workflow_context (success_is_complete = False ):
534491 self ._completed_span (
535492 f"ValidateUpdate:{ input .update } " ,
536493 link_context_carrier = link_context_carrier ,
@@ -540,22 +497,6 @@ def _handle_update_validator(
540497
541498 async def handle_update_handler (
542499 self , input : temporalio .worker .HandleUpdateInput
543- ) -> Any :
544- """Implementation of
545- :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
546- """
547- with self ._top_level_workflow_context (success_is_complete = False ) as ctx :
548- if sys .version_info >= (3 , 11 ):
549- return await asyncio .create_task (
550- self ._handle_update_handler (input ), context = ctx
551- )
552- else :
553- return await ctx .run (
554- asyncio .create_task , self ._handle_update_handler (input )
555- )
556-
557- async def _handle_update_handler (
558- self , input : temporalio .worker .HandleUpdateInput
559500 ) -> Any :
560501 """Implementation of
561502 :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
@@ -566,12 +507,13 @@ async def _handle_update_handler(
566507 link_context_carrier = self .payload_converter .from_payloads (
567508 [link_context_header ]
568509 )[0 ]
569- self ._completed_span (
570- f"HandleUpdate:{ input .update } " ,
571- link_context_carrier = link_context_carrier ,
572- kind = opentelemetry .trace .SpanKind .SERVER ,
573- )
574- return await super ().handle_update_handler (input )
510+ with self ._top_level_workflow_context (success_is_complete = False ):
511+ self ._completed_span (
512+ f"HandleUpdate:{ input .update } " ,
513+ link_context_carrier = link_context_carrier ,
514+ kind = opentelemetry .trace .SpanKind .SERVER ,
515+ )
516+ return await super ().handle_update_handler (input )
575517
576518 def _load_workflow_context_carrier (self ) -> Optional [_CarrierDict ]:
577519 if self ._workflow_context_carrier :
@@ -584,13 +526,47 @@ def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]:
584526 )[0 ]
585527 return self ._workflow_context_carrier
586528
529+ @contextmanager
587530 def _top_level_workflow_context (
588531 self , * , success_is_complete : bool
589- ) -> ContextManager [contextvars .Context ]:
590- return self ._TopLevelWorkflowContextManager (
591- self , success_is_complete = success_is_complete
592- )
532+ ) -> Iterator [None ]:
533+ # Load context only if there is a carrier, otherwise use empty context
534+ context_carrier = self ._load_workflow_context_carrier ()
535+ attach_context : opentelemetry .context .Context
536+ if context_carrier :
537+ attach_context = self .text_map_propagator .extract (context_carrier )
538+ else :
539+ attach_context = opentelemetry .context .Context ()
540+ # We need to put this interceptor on the context too
541+ attach_context = self ._set_on_context (attach_context )
542+ # Need to know whether completed and whether there was a fail-workflow
543+ # exception
544+ success = False
545+ exception : Optional [Exception ] = None
546+ # Run under this context
547+ token = opentelemetry .context .attach (attach_context )
548+
549+ try :
550+ yield None
551+ success = True
552+ except temporalio .exceptions .FailureError as err :
553+ # We only record the failure errors since those are the only ones
554+ # that lead to workflow completions
555+ exception = err
556+ raise
557+ finally :
558+ # Create a completed span before detaching context
559+ if exception or (success and success_is_complete ):
560+ self ._completed_span (
561+ f"CompleteWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
562+ exception = exception ,
563+ kind = opentelemetry .trace .SpanKind .INTERNAL ,
564+ )
565+
566+ if attach_context == opentelemetry .context .get_current ():
567+ opentelemetry .context .detach (token )
593568
569+ #
594570 def _context_to_headers (
595571 self , headers : Mapping [str , temporalio .api .common .v1 .Payload ]
596572 ) -> Mapping [str , temporalio .api .common .v1 .Payload ]:
@@ -664,65 +640,6 @@ def _set_on_context(
664640 ) -> opentelemetry .context .Context :
665641 return opentelemetry .context .set_value (_interceptor_context_key , self , context )
666642
667- class _TopLevelWorkflowContextManager (AbstractContextManager ):
668- def __init__ (
669- self ,
670- interceptor : TracingWorkflowInboundInterceptor ,
671- * ,
672- success_is_complete : bool ,
673- ):
674- self ._ctx = contextvars .copy_context ()
675- self ._token : Optional [contextvars .Token ] = None
676- self ._owner = interceptor
677- self ._success_is_complete = success_is_complete
678-
679- def __enter__ (self ):
680- self ._ctx .run (self ._start )
681- return self ._ctx
682-
683- def __exit__ (
684- self ,
685- exc_type : Optional [type [BaseException ]],
686- exc_value : Optional [BaseException ],
687- traceback : Optional [TracebackType ], # noqa: F811
688- ):
689- self ._ctx .run (self ._end , exc_type , exc_value , traceback )
690-
691- def _start (self ):
692- # Load context only if there is a carrier, otherwise use empty context
693- context_carrier = self ._owner ._load_workflow_context_carrier ()
694- attach_context : opentelemetry .context .Context
695- if context_carrier :
696- attach_context = self ._owner .text_map_propagator .extract (
697- context_carrier
698- )
699- else :
700- attach_context = opentelemetry .context .Context ()
701- # We need to put this interceptor on the context too
702- attach_context = self ._owner ._set_on_context (attach_context )
703- self ._token = opentelemetry .context .attach (attach_context )
704-
705- def _end (
706- self ,
707- exc_type : Optional [type [BaseException ]],
708- exc_value : Optional [BaseException ],
709- _traceback : Optional [TracebackType ],
710- ):
711- success = exc_type is None
712- exception : Optional [temporalio .exceptions .FailureError ] = None
713- if isinstance (exc_value , temporalio .exceptions .FailureError ):
714- exception = exc_value
715-
716- if (success and self ._success_is_complete ) or exception is not None :
717- self ._owner ._completed_span (
718- f"CompleteWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
719- exception = exception ,
720- kind = opentelemetry .trace .SpanKind .INTERNAL ,
721- )
722-
723- if self ._token :
724- opentelemetry .context .detach (self ._token )
725-
726643
727644class _TracingWorkflowOutboundInterceptor (
728645 temporalio .worker .WorkflowOutboundInterceptor
0 commit comments