Skip to content

Commit e22c625

Browse files
committed
wire activity context in worker
1 parent 6839ae4 commit e22c625

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

temporalio/worker/_activity.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,26 @@ async def _heartbeat_async(
252252
if details is None:
253253
return
254254

255+
data_converter = self._data_converter
256+
if activity.info:
257+
context = temporalio.converter.ActivitySerializationContext(
258+
namespace=activity.info.workflow_namespace,
259+
workflow_id=activity.info.workflow_id,
260+
workflow_type=activity.info.workflow_type,
261+
activity_type=activity.info.activity_type,
262+
activity_task_queue=self._task_queue,
263+
is_local=activity.info.is_local,
264+
)
265+
data_converter = data_converter._with_context(context)
266+
255267
# Perform the heartbeat
256268
try:
257269
heartbeat = temporalio.bridge.proto.ActivityHeartbeat( # type: ignore[reportAttributeAccessIssue]
258270
task_token=task_token
259271
)
260272
if details:
261273
# Convert to core payloads
262-
heartbeat.details.extend(await self._data_converter.encode(details))
274+
heartbeat.details.extend(await data_converter.encode(details))
263275
logger.debug("Recording heartbeat with details %s", details)
264276
self._bridge_worker().record_activity_heartbeat(heartbeat)
265277
except Exception as err:
@@ -293,9 +305,21 @@ async def _handle_start_activity_task(
293305
completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue]
294306
task_token=task_token
295307
)
308+
# Create serialization context for the activity
309+
context = temporalio.converter.ActivitySerializationContext(
310+
namespace=start.workflow_namespace,
311+
workflow_id=start.workflow_execution.workflow_id,
312+
workflow_type=start.workflow_type,
313+
activity_type=start.activity_type,
314+
activity_task_queue=self._task_queue,
315+
is_local=start.is_local,
316+
)
317+
data_converter = self._data_converter._with_context(context)
296318
try:
297-
result = await self._execute_activity(start, running_activity, task_token)
298-
[payload] = await self._data_converter.encode([result])
319+
result = await self._execute_activity(
320+
start, running_activity, task_token, data_converter
321+
)
322+
[payload] = await data_converter.encode([result])
299323
completion.result.completed.result.CopyFrom(payload)
300324
except BaseException as err:
301325
try:
@@ -313,7 +337,7 @@ async def _handle_start_activity_task(
313337
temporalio.activity.logger.warning(
314338
f"Completing as failure during heartbeat with error of type {type(err)}: {err}",
315339
)
316-
await self._data_converter.encode_failure(
340+
await data_converter.encode_failure(
317341
err, completion.result.failed.failure
318342
)
319343
elif (
@@ -327,7 +351,7 @@ async def _handle_start_activity_task(
327351
temporalio.activity.logger.warning(
328352
"Completing as failure due to unhandled cancel error produced by activity pause",
329353
)
330-
await self._data_converter.encode_failure(
354+
await data_converter.encode_failure(
331355
temporalio.exceptions.ApplicationError(
332356
type="ActivityPause",
333357
message="Unhandled activity cancel error produced by activity pause",
@@ -345,7 +369,7 @@ async def _handle_start_activity_task(
345369
temporalio.activity.logger.warning(
346370
"Completing as failure due to unhandled cancel error produced by activity reset",
347371
)
348-
await self._data_converter.encode_failure(
372+
await data_converter.encode_failure(
349373
temporalio.exceptions.ApplicationError(
350374
type="ActivityReset",
351375
message="Unhandled activity cancel error produced by activity reset",
@@ -360,7 +384,7 @@ async def _handle_start_activity_task(
360384
and running_activity.cancelled_by_request
361385
):
362386
temporalio.activity.logger.debug("Completing as cancelled")
363-
await self._data_converter.encode_failure(
387+
await data_converter.encode_failure(
364388
# TODO(cretz): Should use some other message?
365389
temporalio.exceptions.CancelledError("Cancelled"),
366390
completion.result.cancelled.failure,
@@ -386,7 +410,7 @@ async def _handle_start_activity_task(
386410
exc_info=True,
387411
extra={"__temporal_error_identifier": "ActivityFailure"},
388412
)
389-
await self._data_converter.encode_failure(
413+
await data_converter.encode_failure(
390414
err, completion.result.failed.failure
391415
)
392416
# For broken executors, we have to fail the entire worker
@@ -428,6 +452,7 @@ async def _execute_activity(
428452
start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue]
429453
running_activity: _RunningActivity,
430454
task_token: bytes,
455+
data_converter: temporalio.converter.DataConverter,
431456
) -> Any:
432457
"""Invoke the user's activity function.
433458
@@ -501,9 +526,7 @@ async def _execute_activity(
501526
args = (
502527
[]
503528
if not start.input
504-
else await self._data_converter.decode(
505-
start.input, type_hints=arg_types
506-
)
529+
else await data_converter.decode(start.input, type_hints=arg_types)
507530
)
508531
except Exception as err:
509532
raise temporalio.exceptions.ApplicationError(
@@ -519,7 +542,7 @@ async def _execute_activity(
519542
heartbeat_details = (
520543
[]
521544
if not start.heartbeat_details
522-
else await self._data_converter.decode(start.heartbeat_details)
545+
else await data_converter.decode(start.heartbeat_details)
523546
)
524547
except Exception as err:
525548
raise temporalio.exceptions.ApplicationError(
@@ -563,11 +586,9 @@ async def _execute_activity(
563586
else None,
564587
)
565588

566-
if self._encode_headers and self._data_converter.payload_codec is not None:
589+
if self._encode_headers and data_converter.payload_codec is not None:
567590
for payload in start.header_fields.values():
568-
new_payload = (
569-
await self._data_converter.payload_codec.decode([payload])
570-
)[0]
591+
new_payload = (await data_converter.payload_codec.decode([payload]))[0]
571592
payload.CopyFrom(new_payload)
572593

573594
running_activity.info = info
@@ -591,7 +612,7 @@ async def _execute_activity(
591612
if not running_activity.cancel_thread_raiser
592613
else running_activity.cancel_thread_raiser.shielded
593614
),
594-
payload_converter_class_or_instance=self._data_converter.payload_converter,
615+
payload_converter_class_or_instance=data_converter.payload_converter,
595616
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
596617
client=self._client if not running_activity.sync else None,
597618
cancellation_details=running_activity.cancellation_details,

0 commit comments

Comments
 (0)