Skip to content

Commit 2efd9a7

Browse files
Add start_update_with_start_workflow to Otel Interceptor (#1150)
* initial draft of fix * remove debug printing * update comment about fan out to specify operation rather than command. Restore existing tracing test to original state * remove copy/pasted todo * Clean up test a little bit * move header fan out outside of try block * Revert changes to client as they are unecessary. Inject otel headers into both operation inputs in otel interceptor * set otel header value directly in update_workflow_input to avoid the extra call to payload conversion
1 parent 2ad41ab commit 2efd9a7

File tree

2 files changed

+129
-2
lines changed

2 files changed

+129
-2
lines changed

temporalio/contrib/opentelemetry.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,30 @@ async def start_workflow_update(
292292
):
293293
return await super().start_workflow_update(input)
294294

295+
async def start_update_with_start_workflow(
296+
self, input: temporalio.client.StartWorkflowUpdateWithStartInput
297+
) -> temporalio.client.WorkflowUpdateHandle[Any]:
298+
attrs = {
299+
"temporalWorkflowID": input.start_workflow_input.id,
300+
}
301+
if input.update_workflow_input.update_id is not None:
302+
attrs["temporalUpdateID"] = input.update_workflow_input.update_id
303+
304+
with self.root._start_as_current_span(
305+
f"StartUpdateWithStartWorkflow:{input.start_workflow_input.workflow}",
306+
attributes=attrs,
307+
input=input.start_workflow_input,
308+
kind=opentelemetry.trace.SpanKind.CLIENT,
309+
):
310+
otel_header = input.start_workflow_input.headers.get(self.root.header_key)
311+
if otel_header:
312+
input.update_workflow_input.headers = {
313+
**input.update_workflow_input.headers,
314+
self.root.header_key: otel_header,
315+
}
316+
317+
return await super().start_update_with_start_workflow(input)
318+
295319

296320
class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor):
297321
def __init__(

tests/contrib/test_opentelemetry.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from opentelemetry.trace import StatusCode, get_tracer
1515

1616
from temporalio import activity, workflow
17-
from temporalio.client import Client
18-
from temporalio.common import RetryPolicy
17+
from temporalio.client import Client, WithStartWorkflowOperation, WorkflowUpdateStage
18+
from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy
1919
from temporalio.contrib.opentelemetry import TracingInterceptor
2020
from temporalio.contrib.opentelemetry import workflow as otel_workflow
2121
from temporalio.exceptions import ApplicationError, ApplicationErrorCategory
@@ -55,6 +55,7 @@ class TracingWorkflowAction:
5555
continue_as_new: Optional[TracingWorkflowActionContinueAsNew] = None
5656
wait_until_signal_count: int = 0
5757
wait_and_do_update: bool = False
58+
wait_and_do_start_with_update: bool = False
5859

5960

6061
@dataclass
@@ -79,13 +80,15 @@ class TracingWorkflowActionContinueAsNew:
7980

8081

8182
ready_for_update: asyncio.Semaphore
83+
ready_for_update_with_start: asyncio.Semaphore
8284

8385

8486
@workflow.defn
8587
class TracingWorkflow:
8688
def __init__(self) -> None:
8789
self._signal_count = 0
8890
self._did_update = False
91+
self._did_update_with_start = False
8992

9093
@workflow.run
9194
async def run(self, param: TracingWorkflowParam) -> None:
@@ -140,6 +143,9 @@ async def run(self, param: TracingWorkflowParam) -> None:
140143
if action.wait_and_do_update:
141144
ready_for_update.release()
142145
await workflow.wait_condition(lambda: self._did_update)
146+
if action.wait_and_do_start_with_update:
147+
ready_for_update_with_start.release()
148+
await workflow.wait_condition(lambda: self._did_update_with_start)
143149

144150
async def _raise_on_non_replay(self) -> None:
145151
replaying = workflow.unsafe.is_replaying()
@@ -161,6 +167,10 @@ def signal(self) -> None:
161167
def update(self) -> None:
162168
self._did_update = True
163169

170+
@workflow.update
171+
def update_with_start(self) -> None:
172+
self._did_update_with_start = True
173+
164174
@update.validator
165175
def update_validator(self) -> None:
166176
pass
@@ -301,6 +311,99 @@ async def test_opentelemetry_tracing(client: Client, env: WorkflowEnvironment):
301311
]
302312

303313

314+
async def test_opentelemetry_tracing_update_with_start(
315+
client: Client, env: WorkflowEnvironment
316+
):
317+
if env.supports_time_skipping:
318+
pytest.skip(
319+
"Java test server: https://github.com/temporalio/sdk-java/issues/1424"
320+
)
321+
global ready_for_update_with_start
322+
ready_for_update_with_start = asyncio.Semaphore(0)
323+
# Create a tracer that has an in-memory exporter
324+
exporter = InMemorySpanExporter()
325+
provider = TracerProvider()
326+
provider.add_span_processor(SimpleSpanProcessor(exporter))
327+
tracer = get_tracer(__name__, tracer_provider=provider)
328+
# Create new client with tracer interceptor
329+
client_config = client.config()
330+
client_config["interceptors"] = [TracingInterceptor(tracer)]
331+
client = Client(**client_config)
332+
333+
task_queue = f"task_queue_{uuid.uuid4()}"
334+
async with Worker(
335+
client,
336+
task_queue=task_queue,
337+
workflows=[TracingWorkflow],
338+
activities=[tracing_activity],
339+
# Needed so we can wait to send update at the right time
340+
workflow_runner=UnsandboxedWorkflowRunner(),
341+
):
342+
# Run workflow with various actions
343+
workflow_id = f"workflow_{uuid.uuid4()}"
344+
workflow_params = TracingWorkflowParam(
345+
actions=[
346+
# Wait for update
347+
TracingWorkflowAction(wait_and_do_start_with_update=True),
348+
]
349+
)
350+
handle = await client.start_workflow(
351+
TracingWorkflow.run,
352+
workflow_params,
353+
id=workflow_id,
354+
task_queue=task_queue,
355+
)
356+
async with ready_for_update_with_start:
357+
start_op = WithStartWorkflowOperation(
358+
TracingWorkflow.run,
359+
workflow_params,
360+
id=handle.id,
361+
task_queue=task_queue,
362+
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
363+
)
364+
await client.start_update_with_start_workflow(
365+
TracingWorkflow.update_with_start,
366+
start_workflow_operation=start_op,
367+
id=handle.id,
368+
wait_for_stage=WorkflowUpdateStage.ACCEPTED,
369+
)
370+
await handle.result()
371+
372+
# issue update with start again to trigger a new workflow
373+
workflow_id = f"workflow_{uuid.uuid4()}"
374+
start_op = WithStartWorkflowOperation(
375+
TracingWorkflow.run,
376+
TracingWorkflowParam(actions=[]),
377+
id=workflow_id,
378+
task_queue=task_queue,
379+
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
380+
)
381+
await client.execute_update_with_start_workflow(
382+
update=TracingWorkflow.update_with_start,
383+
start_workflow_operation=start_op,
384+
id=workflow_id,
385+
)
386+
387+
# Dump debug with attributes, but do string assertion test without
388+
logging.debug(
389+
"Spans:\n%s",
390+
"\n".join(dump_spans(exporter.get_finished_spans(), with_attributes=False)),
391+
)
392+
assert dump_spans(exporter.get_finished_spans(), with_attributes=False) == [
393+
"StartWorkflow:TracingWorkflow",
394+
" RunWorkflow:TracingWorkflow",
395+
" MyCustomSpan",
396+
" HandleUpdate:update_with_start (links: StartUpdateWithStartWorkflow:TracingWorkflow)",
397+
" CompleteWorkflow:TracingWorkflow",
398+
"StartUpdateWithStartWorkflow:TracingWorkflow",
399+
"StartUpdateWithStartWorkflow:TracingWorkflow",
400+
" HandleUpdate:update_with_start (links: StartUpdateWithStartWorkflow:TracingWorkflow)",
401+
" RunWorkflow:TracingWorkflow",
402+
" MyCustomSpan",
403+
" CompleteWorkflow:TracingWorkflow",
404+
]
405+
406+
304407
def dump_spans(
305408
spans: Iterable[ReadableSpan],
306409
*,

0 commit comments

Comments
 (0)