2020)
2121
2222import google .protobuf .internal .containers
23+ from google .protobuf .message import Message
2324from typing_extensions import TypeAlias
2425
2526import temporalio .api .common .v1
3940)
4041from temporalio .bridge .temporal_sdk_bridge import PollShutdownError # type: ignore
4142
43+ from temporalio .api .common .v1 .message_pb2 import Payload
44+ from temporalio .bridge .visitor import visit_payloads , visit_message
45+
4246
4347@dataclass
4448class WorkerConfig :
@@ -368,15 +372,9 @@ async def _encode_payloads(
368372 codec : temporalio .converter .PayloadCodec ,
369373) -> None :
370374 """Encode payloads with the given codec."""
371- return await _apply_to_payloads (payloads , codec .encode )
372-
373-
374- async def _encode_payload (
375- payload : temporalio .api .common .v1 .Payload ,
376- codec : temporalio .converter .PayloadCodec ,
377- ) -> None :
378- """Decode a payload with the given codec."""
379- return await _apply_to_payload (payload , codec .encode )
375+ async def visitor (payload : Payload ) -> Payload :
376+ return (await codec .encode ([payload ]))[0 ]
377+ return await visit_payloads (visitor , payloads )
380378
381379
382380async def decode_activation (
@@ -385,144 +383,18 @@ async def decode_activation(
385383 decode_headers : bool ,
386384) -> None :
387385 """Decode the given activation with the codec."""
388- for job in act .jobs :
389- if job .HasField ("query_workflow" ):
390- await _decode_payloads (job .query_workflow .arguments , codec )
391- if decode_headers :
392- await _decode_headers (job .query_workflow .headers , codec )
393- elif job .HasField ("resolve_activity" ):
394- if job .resolve_activity .result .HasField ("cancelled" ):
395- await codec .decode_failure (
396- job .resolve_activity .result .cancelled .failure
397- )
398- elif job .resolve_activity .result .HasField ("completed" ):
399- if job .resolve_activity .result .completed .HasField ("result" ):
400- await _decode_payload (
401- job .resolve_activity .result .completed .result , codec
402- )
403- elif job .resolve_activity .result .HasField ("failed" ):
404- await codec .decode_failure (job .resolve_activity .result .failed .failure )
405- elif job .HasField ("resolve_child_workflow_execution" ):
406- if job .resolve_child_workflow_execution .result .HasField ("cancelled" ):
407- await codec .decode_failure (
408- job .resolve_child_workflow_execution .result .cancelled .failure
409- )
410- elif job .resolve_child_workflow_execution .result .HasField (
411- "completed"
412- ) and job .resolve_child_workflow_execution .result .completed .HasField (
413- "result"
414- ):
415- await _decode_payload (
416- job .resolve_child_workflow_execution .result .completed .result , codec
417- )
418- elif job .resolve_child_workflow_execution .result .HasField ("failed" ):
419- await codec .decode_failure (
420- job .resolve_child_workflow_execution .result .failed .failure
421- )
422- elif job .HasField ("resolve_child_workflow_execution_start" ):
423- if job .resolve_child_workflow_execution_start .HasField ("cancelled" ):
424- await codec .decode_failure (
425- job .resolve_child_workflow_execution_start .cancelled .failure
426- )
427- elif job .HasField ("resolve_request_cancel_external_workflow" ):
428- if job .resolve_request_cancel_external_workflow .HasField ("failure" ):
429- await codec .decode_failure (
430- job .resolve_request_cancel_external_workflow .failure
431- )
432- elif job .HasField ("resolve_signal_external_workflow" ):
433- if job .resolve_signal_external_workflow .HasField ("failure" ):
434- await codec .decode_failure (job .resolve_signal_external_workflow .failure )
435- elif job .HasField ("signal_workflow" ):
436- await _decode_payloads (job .signal_workflow .input , codec )
437- if decode_headers :
438- await _decode_headers (job .signal_workflow .headers , codec )
439- elif job .HasField ("initialize_workflow" ):
440- await _decode_payloads (job .initialize_workflow .arguments , codec )
441- if decode_headers :
442- await _decode_headers (job .initialize_workflow .headers , codec )
443- if job .initialize_workflow .HasField ("continued_failure" ):
444- await codec .decode_failure (job .initialize_workflow .continued_failure )
445- for val in job .initialize_workflow .memo .fields .values ():
446- # This uses API payload not bridge payload
447- new_payload = (await codec .decode ([val ]))[0 ]
448- # Make a shallow copy, in case new_payload.metadata and val.metadata are
449- # references to the same memory, e.g. decode() returns its input unchanged.
450- new_metadata = dict (new_payload .metadata )
451- val .metadata .clear ()
452- val .metadata .update (new_metadata )
453- val .data = new_payload .data
454- elif job .HasField ("do_update" ):
455- await _decode_payloads (job .do_update .input , codec )
456- if decode_headers :
457- await _decode_headers (job .do_update .headers , codec )
386+ async def visitor (payload : Payload ) -> Payload :
387+ return (await codec .decode ([payload ]))[0 ]
458388
389+ await visit_message (visitor , act )
459390
460391async def encode_completion (
461392 comp : temporalio .bridge .proto .workflow_completion .WorkflowActivationCompletion ,
462393 codec : temporalio .converter .PayloadCodec ,
463394 encode_headers : bool ,
464395) -> None :
465396 """Recursively encode the given completion with the codec."""
466- if comp .HasField ("failed" ):
467- await codec .encode_failure (comp .failed .failure )
468- elif comp .HasField ("successful" ):
469- for command in comp .successful .commands :
470- if command .HasField ("complete_workflow_execution" ):
471- if command .complete_workflow_execution .HasField ("result" ):
472- await _encode_payload (
473- command .complete_workflow_execution .result , codec
474- )
475- elif command .HasField ("continue_as_new_workflow_execution" ):
476- await _encode_payloads (
477- command .continue_as_new_workflow_execution .arguments , codec
478- )
479- if encode_headers :
480- await _encode_headers (
481- command .continue_as_new_workflow_execution .headers , codec
482- )
483- for val in command .continue_as_new_workflow_execution .memo .values ():
484- await _encode_payload (val , codec )
485- elif command .HasField ("fail_workflow_execution" ):
486- await codec .encode_failure (command .fail_workflow_execution .failure )
487- elif command .HasField ("respond_to_query" ):
488- if command .respond_to_query .HasField ("failed" ):
489- await codec .encode_failure (command .respond_to_query .failed )
490- elif command .respond_to_query .HasField (
491- "succeeded"
492- ) and command .respond_to_query .succeeded .HasField ("response" ):
493- await _encode_payload (
494- command .respond_to_query .succeeded .response , codec
495- )
496- elif command .HasField ("schedule_activity" ):
497- await _encode_payloads (command .schedule_activity .arguments , codec )
498- if encode_headers :
499- await _encode_headers (command .schedule_activity .headers , codec )
500- elif command .HasField ("schedule_local_activity" ):
501- await _encode_payloads (command .schedule_local_activity .arguments , codec )
502- if encode_headers :
503- await _encode_headers (
504- command .schedule_local_activity .headers , codec
505- )
506- elif command .HasField ("signal_external_workflow_execution" ):
507- await _encode_payloads (
508- command .signal_external_workflow_execution .args , codec
509- )
510- if encode_headers :
511- await _encode_headers (
512- command .signal_external_workflow_execution .headers , codec
513- )
514- elif command .HasField ("start_child_workflow_execution" ):
515- await _encode_payloads (
516- command .start_child_workflow_execution .input , codec
517- )
518- if encode_headers :
519- await _encode_headers (
520- command .start_child_workflow_execution .headers , codec
521- )
522- for val in command .start_child_workflow_execution .memo .values ():
523- await _encode_payload (val , codec )
524- elif command .HasField ("update_response" ):
525- if command .update_response .HasField ("completed" ):
526- await _encode_payload (command .update_response .completed , codec )
527- elif command .update_response .HasField ("rejected" ):
528- await codec .encode_failure (command .update_response .rejected )
397+ async def visitor (payload : Payload ) -> Payload :
398+ return (await codec .encode ([payload ]))[0 ]
399+
400+ await visit_message (visitor , comp )
0 commit comments