|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import asyncio |
| 6 | +import contextvars |
5 | 7 | import random |
6 | 8 | import uuid |
7 | 9 | from contextlib import contextmanager |
@@ -400,48 +402,58 @@ async def signal_external_workflow( |
400 | 402 | def start_activity( |
401 | 403 | self, input: temporalio.worker.StartActivityInput |
402 | 404 | ) -> temporalio.workflow.ActivityHandle: |
403 | | - trace = get_trace_provider().get_current_trace() |
404 | | - span: Optional[Span] = None |
405 | | - if trace: |
406 | | - span = custom_span( |
407 | | - name="temporal:startActivity", data={"activity": input.activity} |
408 | | - ) |
409 | | - span.start(mark_as_current=True) |
410 | | - |
411 | | - set_header_from_context(input, temporalio.workflow.payload_converter()) |
412 | | - handle = self.next.start_activity(input) |
| 405 | + ctx = contextvars.copy_context() |
| 406 | + span = ctx.run( |
| 407 | + self._create_span, |
| 408 | + name="temporal:startActivity", |
| 409 | + data={"activity": input.activity}, |
| 410 | + input=input, |
| 411 | + ) |
| 412 | + handle = ctx.run(self.next.start_activity, input) |
413 | 413 | if span: |
414 | | - handle.add_done_callback(lambda _: span.finish()) # type: ignore |
| 414 | + handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore |
415 | 415 | return handle |
416 | 416 |
|
417 | 417 | async def start_child_workflow( |
418 | 418 | self, input: temporalio.worker.StartChildWorkflowInput |
419 | 419 | ) -> temporalio.workflow.ChildWorkflowHandle: |
420 | | - trace = get_trace_provider().get_current_trace() |
421 | | - span: Optional[Span] = None |
422 | | - if trace: |
423 | | - span = custom_span( |
424 | | - name="temporal:startChildWorkflow", data={"workflow": input.workflow} |
425 | | - ) |
426 | | - span.start(mark_as_current=True) |
427 | | - set_header_from_context(input, temporalio.workflow.payload_converter()) |
428 | | - handle = await self.next.start_child_workflow(input) |
| 420 | + ctx = contextvars.copy_context() |
| 421 | + span = ctx.run( |
| 422 | + self._create_span, |
| 423 | + name="temporal:startChildWorkflow", |
| 424 | + data={"workflow": input.workflow}, |
| 425 | + input=input, |
| 426 | + ) |
| 427 | + handle = await ctx.run( |
| 428 | + asyncio.create_task, self.next.start_child_workflow(input) |
| 429 | + ) |
429 | 430 | if span: |
430 | | - handle.add_done_callback(lambda _: span.finish()) # type: ignore |
| 431 | + handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore |
431 | 432 | return handle |
432 | 433 |
|
433 | 434 | def start_local_activity( |
434 | 435 | self, input: temporalio.worker.StartLocalActivityInput |
435 | 436 | ) -> temporalio.workflow.ActivityHandle: |
| 437 | + ctx = contextvars.copy_context() |
| 438 | + span = ctx.run( |
| 439 | + self._create_span, |
| 440 | + name="temporal:startLocalActivity", |
| 441 | + data={"activity": input.activity}, |
| 442 | + input=input, |
| 443 | + ) |
| 444 | + handle = ctx.run(self.next.start_local_activity, input) |
| 445 | + if span: |
| 446 | + handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore |
| 447 | + return handle |
| 448 | + |
| 449 | + def _create_span( |
| 450 | + self, name: str, data: dict[str, Any], input: _InputWithHeaders |
| 451 | + ) -> Optional[Span]: |
436 | 452 | trace = get_trace_provider().get_current_trace() |
437 | 453 | span: Optional[Span] = None |
438 | 454 | if trace: |
439 | | - span = custom_span( |
440 | | - name="temporal:startLocalActivity", data={"activity": input.activity} |
441 | | - ) |
| 455 | + span = custom_span(name=name, data=data) |
442 | 456 | span.start(mark_as_current=True) |
| 457 | + |
443 | 458 | set_header_from_context(input, temporalio.workflow.payload_converter()) |
444 | | - handle = self.next.start_local_activity(input) |
445 | | - if span: |
446 | | - handle.add_done_callback(lambda _: span.finish()) # type: ignore |
447 | | - return handle |
| 459 | + return span |
0 commit comments