Skip to content

Commit c9b4b13

Browse files
committed
wip: test reflection isn't broken
1 parent 6d703c0 commit c9b4b13

File tree

3 files changed

+104
-30
lines changed

3 files changed

+104
-30
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import dataclasses
2+
import logging
3+
import typing
4+
5+
import temporalio.worker
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
@dataclasses.dataclass(frozen=True)
11+
class InterceptedActivity:
12+
class_name: str
13+
name: typing.Optional[str]
14+
qualname: typing.Optional[str]
15+
module: typing.Optional[str]
16+
annotations: typing.Dict[str, typing.Any]
17+
docstring: typing.Optional[str]
18+
19+
20+
class ReflectionInterceptor(temporalio.worker.Interceptor):
21+
"""Interceptor to check we haven't broken reflection when wrapping the activity."""
22+
23+
def __init__(self) -> None:
24+
self._intercepted_activities: list[InterceptedActivity] = []
25+
26+
def get_intercepted_activities(self) -> typing.List[InterceptedActivity]:
27+
"""Get the list of intercepted activities."""
28+
return self._intercepted_activities
29+
30+
def intercept_activity(
31+
self, next: temporalio.worker.ActivityInboundInterceptor
32+
) -> temporalio.worker.ActivityInboundInterceptor:
33+
"""Method called for intercepting an activity.
34+
35+
Args:
36+
next: The underlying inbound interceptor this interceptor should
37+
delegate to.
38+
39+
Returns:
40+
The new interceptor that will be used to for the activity.
41+
"""
42+
return _ReflectionActivityInboundInterceptor(next, self)
43+
44+
45+
class _ReflectionActivityInboundInterceptor(
46+
temporalio.worker.ActivityInboundInterceptor
47+
):
48+
def __init__(
49+
self,
50+
next: temporalio.worker.ActivityInboundInterceptor,
51+
root: ReflectionInterceptor,
52+
) -> None:
53+
super().__init__(next)
54+
self.root = root
55+
56+
async def execute_activity(
57+
self, input: temporalio.worker.ExecuteActivityInput
58+
) -> typing.Any:
59+
"""Called to invoke the activity."""
60+
61+
try:
62+
self.root._intercepted_activities.append(
63+
InterceptedActivity(
64+
class_name=input.fn.__class__.__name__,
65+
name=getattr(input.fn, "__name__", None),
66+
qualname=getattr(input.fn, "__qualname__", None),
67+
module=getattr(input.fn, "__module__", None),
68+
docstring=getattr(input.fn, "__doc__", None),
69+
annotations=getattr(input.fn, "__annotations__", {}),
70+
)
71+
)
72+
except AttributeError:
73+
logger.exception(
74+
"Activity function does not have expected attributes, skipping reflection."
75+
)
76+
77+
return await self.next.execute_activity(input)

tests/contrib/opentelemetry/helpers/tracing.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,18 @@
11
from __future__ import annotations
22

3-
import asyncio
4-
import concurrent.futures
5-
import logging
63
import multiprocessing
74
import multiprocessing.managers
85
import threading
96
import typing
10-
import uuid
117
from dataclasses import dataclass
12-
from datetime import timedelta
138
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
149

1510
import opentelemetry.trace
16-
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
11+
from opentelemetry.sdk.trace import ReadableSpan
1712
from opentelemetry.sdk.trace.export import (
18-
SimpleSpanProcessor,
1913
SpanExporter,
2014
SpanExportResult,
2115
)
22-
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
23-
from opentelemetry.trace import get_current_span, get_tracer
24-
25-
from temporalio import activity, workflow
26-
from temporalio.client import Client
27-
from temporalio.common import RetryPolicy
28-
from temporalio.contrib.opentelemetry import TracingInterceptor
29-
from temporalio.contrib.opentelemetry import workflow as otel_workflow
30-
from temporalio.testing import WorkflowEnvironment
31-
from temporalio.worker import SharedStateManager, UnsandboxedWorkflowRunner, Worker
3216

3317

3418
@dataclass(frozen=True)
@@ -85,14 +69,9 @@ def from_readable_span(cls, span: ReadableSpan) -> "SerialisableSpan":
8569
)
8670

8771

88-
SerialisableSpanListProxy: typing.TypeAlias = multiprocessing.managers.ListProxy[
89-
SerialisableSpan
90-
]
91-
92-
9372
def make_span_proxy_list(
9473
manager: multiprocessing.managers.SyncManager,
95-
) -> SerialisableSpanListProxy:
74+
) -> multiprocessing.managers.ListProxy[SerialisableSpan]:
9675
"""Create a list proxy to share `SerialisableSpan` across processes."""
9776
return manager.list()
9877

@@ -110,7 +89,9 @@ class _ListProxySpanExporter(SpanExporter):
11089
into a single trace.
11190
"""
11291

113-
def __init__(self, finished_spans: SerialisableSpanListProxy) -> None:
92+
def __init__(
93+
self, finished_spans: multiprocessing.managers.ListProxy[SerialisableSpan]
94+
) -> None:
11495
self._finished_spans = finished_spans
11596
self._stopped = False
11697
self._lock = threading.Lock()

tests/contrib/test_opentelemetry.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@
2626
from temporalio.contrib.opentelemetry import workflow as otel_workflow
2727
from temporalio.testing import WorkflowEnvironment
2828
from temporalio.worker import SharedStateManager, UnsandboxedWorkflowRunner, Worker
29+
from tests.contrib.opentelemetry.helpers.reflection_interceptor import (
30+
InterceptedActivity,
31+
ReflectionInterceptor,
32+
)
2933
from tests.contrib.opentelemetry.helpers.tracing import (
3034
SerialisableSpan,
3135
_ListProxySpanExporter,
3236
dump_spans,
3337
make_span_proxy_list,
34-
SerialisableSpanListProxy,
3538
)
3639

3740
# Passing through because Python 3.9 has an import bug at
@@ -403,9 +406,6 @@ async def test_activity_trace_propagation(
403406
client: Client,
404407
env: WorkflowEnvironment,
405408
):
406-
# TODO: add spy interceptor to check `input.fn` wraps original metadata
407-
# TODO: Add Resource to show how resource would be propagated
408-
409409
# Create a tracer that has an in-memory exporter
410410
exporter = InMemorySpanExporter()
411411
provider = TracerProvider()
@@ -417,13 +417,16 @@ async def test_activity_trace_propagation(
417417
manager = multiprocessing.Manager()
418418
finished_spans_proxy = make_span_proxy_list(manager)
419419

420+
# Create an interceptor to test we haven't broken reflection
421+
reflection_interceptor = ReflectionInterceptor()
422+
420423
# Create a worker with a process pool activity executor
421424
async with Worker(
422425
client,
423426
task_queue=f"task_queue_{uuid.uuid4()}",
424427
workflows=[ActivityTracePropagationWorkflow],
425428
activities=[sync_activity],
426-
interceptors=[TracingInterceptor(tracer)],
429+
interceptors=[TracingInterceptor(tracer), reflection_interceptor],
427430
activity_executor=concurrent.futures.ProcessPoolExecutor(
428431
max_workers=1,
429432
initializer=activity_trace_propagation_initializer,
@@ -437,16 +440,29 @@ async def test_activity_trace_propagation(
437440
task_queue=worker.task_queue,
438441
)
439442

443+
# The dumped spans should include child spans created in the child process
440444
spans = exporter.get_finished_spans() + tuple(finished_spans_proxy)
441445
logging.debug("Spans:\n%s", "\n".join(dump_spans(spans, with_attributes=False)))
442446
assert dump_spans(spans, with_attributes=False) == [
443447
"RunActivity:sync_activity",
444448
" child_span",
445449
]
446450

451+
# and the activity should still have the original attributes in downstream interceptors
452+
assert reflection_interceptor.get_intercepted_activities() == [
453+
InterceptedActivity(
454+
class_name="ActivityFnWithTraceContext",
455+
name="sync_activity",
456+
qualname="sync_activity",
457+
module="tests.contrib.test_opentelemetry",
458+
docstring="An activity that uses tracing features.\n\nWhen executed in a process pool, we expect the trace context to be available\nfrom the parent process.\n",
459+
annotations={"param": "typing.Any", "return": "str"},
460+
)
461+
]
462+
447463

448464
def activity_trace_propagation_initializer(
449-
_finished_spans_proxy: SerialisableSpanListProxy,
465+
_finished_spans_proxy: multiprocessing.managers.ListProxy[SerialisableSpan],
450466
) -> None:
451467
"""Initializer for the process pool worker to export spans to a shared list."""
452468
_exporter = _ListProxySpanExporter(_finished_spans_proxy)

0 commit comments

Comments
 (0)