diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index f35f41d71..790169afd 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -32,6 +32,11 @@ def emit_loop( if not self.skip_headers: for v in {iter_expr}: await self._visit_{child_method}(fs, v)""" + elif field_name == "search_attributes": + return f"""\ + if not self.skip_search_attributes: + for v in {iter_expr}: + await self._visit_{child_method}(fs, v)""" else: return f"""\ for v in {iter_expr}: @@ -197,7 +202,7 @@ def walk(self, desc: Descriptor) -> bool: # Process regular fields first for field in regular_fields: # Repeated fields (including maps which are represented as repeated messages) - if field.label == FieldDescriptor.LABEL_REPEATED: + if field.is_repeated: if ( field.message_type is not None and field.message_type.GetOptions().map_entry diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index c7e38af37..0491b5e88 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -320,8 +320,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.search_attributes.values(): - await self._visit_temporal_api_common_v1_Payload(fs, v) + if not self.skip_search_attributes: + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): await self._visit_payload_container(fs, o.input) @@ -330,8 +331,9 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, await self._visit_temporal_api_common_v1_Payload(fs, v) for v in o.memo.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.search_attributes.values(): - await self._visit_temporal_api_common_v1_Payload(fs, v) + if not self.skip_search_attributes: + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( self, fs, o @@ -350,8 +352,9 @@ async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( self, fs, o ): - for v in o.search_attributes.values(): - await self._visit_temporal_api_common_v1_Payload(fs, v) + if not self.skip_search_attributes: + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): if o.HasField("upserted_memo"): diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 4ca9890c2..9661ad7cc 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -44,6 +44,7 @@ import temporalio.activity import temporalio.api.sdk.v1 import temporalio.client +import temporalio.converter import temporalio.worker import temporalio.workflow from temporalio import activity, workflow @@ -8369,3 +8370,109 @@ async def test_previous_run_failure(client: Client): ) result = await handle.result() assert result == "Done" + + +class EncryptionCodec(PayloadCodec): + def __init__( + self, + key_id: str = "test-key-id", + key: bytes = b"test-key-test-key-test-key-test!", + ) -> None: + super().__init__() + self.key_id = key_id + + async def encode(self, payloads: Sequence[Payload]) -> List[Payload]: + # We blindly encode all payloads with the key and set the metadata + # saying which key we used + return [ + Payload( + metadata={ + "encoding": b"binary/encrypted", + "encryption-key-id": self.key_id.encode(), + }, + data=self.encrypt(p.SerializeToString()), + ) + for p in payloads + ] + + async def decode(self, payloads: Sequence[Payload]) -> List[Payload]: + ret: List[Payload] = [] + for p in payloads: + # Ignore ones w/out our expected encoding + if p.metadata.get("encoding", b"").decode() != "binary/encrypted": + ret.append(p) + continue + # Confirm our key ID is the same + key_id = p.metadata.get("encryption-key-id", b"").decode() + if key_id != self.key_id: + raise ValueError( + f"Unrecognized key ID {key_id}. Current key ID is {self.key_id}." + ) + # Decrypt and append + ret.append(Payload.FromString(self.decrypt(p.data))) + return ret + + def encrypt(self, data: bytes) -> bytes: + nonce = os.urandom(12) + return data + + def decrypt(self, data: bytes) -> bytes: + return data + + +@workflow.defn +class SearchAttributeCodecParentWorkflow: + @workflow.run + async def run(self, name: str) -> str: + print( + await workflow.execute_child_workflow( + workflow=SearchAttributeCodecChildWorkflow.run, + arg=name, + id=f"child-{name}", + search_attributes=workflow.info().typed_search_attributes, + ) + ) + return f"Hello, {name}" + + +@workflow.defn +class SearchAttributeCodecChildWorkflow: + @workflow.run + async def run(self, name: str) -> str: + return f"Hello from child, {name}" + + +async def test_search_attribute_codec(client: Client, env_type: str): + if env_type != "local": + pytest.skip("Only testing search attributes on local which disables cache") + await ensure_search_attributes_present( + client, + SearchAttributeWorkflow.text_attribute, + ) + + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), payload_codec=EncryptionCodec() + ) + client = Client(**config) + + # Run a worker for the workflow + async with new_worker( + client, + SearchAttributeCodecParentWorkflow, + SearchAttributeCodecChildWorkflow, + ) as worker: + # Run workflow + result = await client.execute_workflow( + SearchAttributeCodecParentWorkflow.run, + "Temporal", + id=f"encryption-workflow-id", + task_queue=worker.task_queue, + search_attributes=TypedSearchAttributes( + [ + SearchAttributePair( + SearchAttributeWorkflow.text_attribute, "test_text" + ) + ] + ), + )