Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ temporalio/bridge/temporal_sdk_bridge*
/sdk-python.iml
/.zed
*.DS_Store
tags
49 changes: 27 additions & 22 deletions temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,29 +172,34 @@ def _start_as_current_span(
kind: opentelemetry.trace.SpanKind,
context: Optional[Context] = None,
) -> Iterator[None]:
with self.tracer.start_as_current_span(
name,
attributes=attributes,
kind=kind,
context=context,
set_status_on_exception=False,
) as span:
if input:
input.headers = self._context_to_headers(input.headers)
try:
yield None
except Exception as exc:
if (
not isinstance(exc, ApplicationError)
or exc.category != ApplicationErrorCategory.BENIGN
):
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
token = opentelemetry.context.attach(context) if context else None
try:
with self.tracer.start_as_current_span(
name,
attributes=attributes,
kind=kind,
context=context,
set_status_on_exception=False,
) as span:
if input:
input.headers = self._context_to_headers(input.headers)
try:
yield None
except Exception as exc:
if (
not isinstance(exc, ApplicationError)
or exc.category != ApplicationErrorCategory.BENIGN
):
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
)
)
)
raise
raise
finally:
if token and context is opentelemetry.context.get_current():
opentelemetry.context.detach(token)

def _completed_workflow_span(
self, params: _CompletedWorkflowSpanParams
Expand Down
287 changes: 286 additions & 1 deletion tests/contrib/test_opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Iterable, List, Optional
from typing import Callable, Dict, Generator, Iterable, List, Optional, cast

import opentelemetry.context
import pytest
from opentelemetry import baggage, context
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
Expand Down Expand Up @@ -558,6 +560,289 @@ async def test_opentelemetry_benign_exception(client: Client):
assert all(span.status.status_code == StatusCode.UNSET for span in spans)


@contextmanager
def baggage_values(values: Dict[str, str]) -> Generator[None, None, None]:
ctx = context.get_current()
for key, value in values.items():
ctx = baggage.set_baggage(key, value, context=ctx)

token = context.attach(ctx)
try:
yield
finally:
context.detach(token)


@pytest.fixture
def client_with_tracing(client: Client) -> Client:
tracer = get_tracer(__name__, tracer_provider=TracerProvider())
client_config = client.config()
client_config["interceptors"] = [TracingInterceptor(tracer)]
return Client(**client_config)


def get_baggage_value(key: str) -> str:
return cast("str", baggage.get_baggage(key))


@activity.defn
async def read_baggage_activity() -> Dict[str, str]:
return {
"user_id": get_baggage_value("user.id"),
"tenant_id": get_baggage_value("tenant.id"),
}


@workflow.defn
class ReadBaggageTestWorkflow:
@workflow.run
async def run(self) -> Dict[str, str]:
return await workflow.execute_activity(
read_baggage_activity,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_opentelemetry_baggage_propagation_basic(
client_with_tracing: Client, env: WorkflowEnvironment
):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[ReadBaggageTestWorkflow],
activities=[read_baggage_activity],
):
with baggage_values({"user.id": "test-user-123", "tenant.id": "some-corp"}):
result = await client_with_tracing.execute_workflow(
ReadBaggageTestWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

assert (
result["user_id"] == "test-user-123"
), "user.id baggage should propagate to activity"
assert (
result["tenant_id"] == "some-corp"
), "tenant.id baggage should propagate to activity"


@activity.defn
async def read_baggage_local_activity() -> Dict[str, str]:
return cast(
Dict[str, str],
{
"user_id": get_baggage_value("user.id"),
"tenant_id": get_baggage_value("tenant.id"),
},
)


@workflow.defn
class LocalActivityBaggageTestWorkflow:
@workflow.run
async def run(self) -> Dict[str, str]:
return await workflow.execute_local_activity(
read_baggage_local_activity,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_opentelemetry_baggage_propagation_local_activity(
client_with_tracing: Client, env: WorkflowEnvironment
):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[LocalActivityBaggageTestWorkflow],
activities=[read_baggage_local_activity],
):
with baggage_values(
{
"user.id": "test-user-456",
"tenant.id": "local-corp",
}
):
result = await client_with_tracing.execute_workflow(
LocalActivityBaggageTestWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

assert result["user_id"] == "test-user-456"
assert result["tenant_id"] == "local-corp"


retry_attempt_baggage_values: List[str] = []


@activity.defn
async def failing_baggage_activity() -> None:
retry_attempt_baggage_values.append(get_baggage_value("user.id"))
if activity.info().attempt < 2:
raise RuntimeError("Intentional failure")


@workflow.defn
class RetryBaggageTestWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.execute_activity(
failing_baggage_activity,
start_to_close_timeout=timedelta(seconds=10),
retry_policy=RetryPolicy(initial_interval=timedelta(milliseconds=1)),
)


async def test_opentelemetry_baggage_propagation_with_retries(
client_with_tracing: Client, env: WorkflowEnvironment
) -> None:
global retry_attempt_baggage_values
retry_attempt_baggage_values = []

task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[RetryBaggageTestWorkflow],
activities=[failing_baggage_activity],
):
with baggage_values({"user.id": "test-user-retry"}):
await client_with_tracing.execute_workflow(
RetryBaggageTestWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

# Verify baggage was present on all attempts
assert len(retry_attempt_baggage_values) == 2
assert all(v == "test-user-retry" for v in retry_attempt_baggage_values)


@activity.defn
async def context_clear_noop_activity() -> None:
pass


@activity.defn
async def context_clear_exception_activity() -> None:
raise Exception("Simulated exception")


@workflow.defn
class ContextClearWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.execute_activity(
context_clear_noop_activity,
start_to_close_timeout=timedelta(seconds=10),
retry_policy=RetryPolicy(
maximum_attempts=1, initial_interval=timedelta(milliseconds=1)
),
)


EXPECT_FAILURE = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the deal with this global? Are you just trying to put True and False in the parameters? If that's the case, just do that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is just a way to make parameters in parametrized test more readable. To have this

    [
        (context_clear_noop_activity, not EXPECT_FAILURE),
        (context_clear_exception_activity, EXPECT_FAILURE),
    ],

instead of this

    [
        (context_clear_noop_activity, not True),
        (context_clear_exception_activity, True),
    ],

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively I can create an enum or even a dataclass for all parameters. I decided againts that as that is move verbose with very little gain IMHO



@pytest.mark.parametrize(
"activity,expect_failure",
[
(context_clear_noop_activity, not EXPECT_FAILURE),
(context_clear_exception_activity, EXPECT_FAILURE),
],
)
async def test_opentelemetry_context_restored_after_activity(
client_with_tracing: Client,
env: WorkflowEnvironment,
activity: Callable[[], None],
expect_failure: bool,
) -> None:
attach_count = 0
detach_count = 0
original_attach = context.attach
original_detach = context.detach

def tracked_attach(ctx):
nonlocal attach_count
attach_count += 1
return original_attach(ctx)

def tracked_detach(token):
nonlocal detach_count
detach_count += 1
return original_detach(token)

context.attach = tracked_attach
context.detach = tracked_detach

try:
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[ContextClearWorkflow],
activities=[activity],
):
with baggage_values({"user.id": "test-123"}):
try:
await client_with_tracing.execute_workflow(
ContextClearWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)
assert (
not expect_failure
), "This test should have raised an exception"
except Exception:
assert expect_failure, "This test is not expeced to raise"

assert (
attach_count == detach_count
), f"Context leak detected: {attach_count} attaches vs {detach_count} detaches. "
assert attach_count > 0, "Expected at least one context attach/detach"

finally:
context.attach = original_attach
context.detach = original_detach


@activity.defn
async def simple_no_context_activity() -> str:
return "success"


@workflow.defn
class SimpleNoContextWorkflow:
@workflow.run
async def run(self) -> str:
return await workflow.execute_activity(
simple_no_context_activity,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_opentelemetry_interceptor_works_if_no_context(
client_with_tracing: Client, env: WorkflowEnvironment
):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[SimpleNoContextWorkflow],
activities=[simple_no_context_activity],
):
result = await client_with_tracing.execute_workflow(
SimpleNoContextWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

assert result == "success"


# TODO(cretz): Additional tests to write
# * query without interceptor (no headers)
# * workflow without interceptor (no headers) but query with interceptor (headers)
Expand Down
Loading