Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion scripts/gen_payload_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def emit_loop(
if not self.skip_headers:
for v in {iter_expr}:
await self._visit_{child_method}(fs, v)"""
elif field_name == "search_attributes":
return f"""\
if not self.skip_search_attributes:
for v in {iter_expr}:
await self._visit_{child_method}(fs, v)"""
else:
return f"""\
for v in {iter_expr}:
Expand Down Expand Up @@ -197,7 +202,7 @@ def walk(self, desc: Descriptor) -> bool:
# Process regular fields first
for field in regular_fields:
# Repeated fields (including maps which are represented as repeated messages)
if field.label == FieldDescriptor.LABEL_REPEATED:
if field.is_repeated:
if (
field.message_type is not None
and field.message_type.GetOptions().map_entry
Expand Down
15 changes: 9 additions & 6 deletions temporalio/bridge/_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(
if not self.skip_headers:
for v in o.headers.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)
for v in o.search_attributes.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)
if not self.skip_search_attributes:
for v in o.search_attributes.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)

async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o):
await self._visit_payload_container(fs, o.input)
Expand All @@ -330,8 +331,9 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs,
await self._visit_temporal_api_common_v1_Payload(fs, v)
for v in o.memo.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)
for v in o.search_attributes.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)
if not self.skip_search_attributes:
for v in o.search_attributes.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)

async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(
self, fs, o
Expand All @@ -350,8 +352,9 @@ async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o):
async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(
self, fs, o
):
for v in o.search_attributes.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)
if not self.skip_search_attributes:
for v in o.search_attributes.values():
await self._visit_temporal_api_common_v1_Payload(fs, v)

async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o):
if o.HasField("upserted_memo"):
Expand Down
107 changes: 107 additions & 0 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import temporalio.activity
import temporalio.api.sdk.v1
import temporalio.client
import temporalio.converter
import temporalio.worker
import temporalio.workflow
from temporalio import activity, workflow
Expand Down Expand Up @@ -8369,3 +8370,109 @@ async def test_previous_run_failure(client: Client):
)
result = await handle.result()
assert result == "Done"


class EncryptionCodec(PayloadCodec):
def __init__(
self,
key_id: str = "test-key-id",
key: bytes = b"test-key-test-key-test-key-test!",
) -> None:
super().__init__()
self.key_id = key_id

async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
# We blindly encode all payloads with the key and set the metadata
# saying which key we used
return [
Payload(
metadata={
"encoding": b"binary/encrypted",
"encryption-key-id": self.key_id.encode(),
},
data=self.encrypt(p.SerializeToString()),
)
for p in payloads
]

async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
ret: List[Payload] = []
for p in payloads:
# Ignore ones w/out our expected encoding
if p.metadata.get("encoding", b"").decode() != "binary/encrypted":
ret.append(p)
continue
# Confirm our key ID is the same
key_id = p.metadata.get("encryption-key-id", b"").decode()
if key_id != self.key_id:
raise ValueError(
f"Unrecognized key ID {key_id}. Current key ID is {self.key_id}."
)
# Decrypt and append
ret.append(Payload.FromString(self.decrypt(p.data)))
return ret

def encrypt(self, data: bytes) -> bytes:
nonce = os.urandom(12)
return data

def decrypt(self, data: bytes) -> bytes:
return data


@workflow.defn
class SearchAttributeCodecParentWorkflow:
@workflow.run
async def run(self, name: str) -> str:
print(
await workflow.execute_child_workflow(
workflow=SearchAttributeCodecChildWorkflow.run,
arg=name,
id=f"child-{name}",
search_attributes=workflow.info().typed_search_attributes,
)
)
return f"Hello, {name}"


@workflow.defn
class SearchAttributeCodecChildWorkflow:
@workflow.run
async def run(self, name: str) -> str:
return f"Hello from child, {name}"


async def test_search_attribute_codec(client: Client, env_type: str):
if env_type != "local":
pytest.skip("Only testing search attributes on local which disables cache")
await ensure_search_attributes_present(
client,
SearchAttributeWorkflow.text_attribute,
)

config = client.config()
config["data_converter"] = dataclasses.replace(
temporalio.converter.default(), payload_codec=EncryptionCodec()
)
client = Client(**config)

# Run a worker for the workflow
async with new_worker(
client,
SearchAttributeCodecParentWorkflow,
SearchAttributeCodecChildWorkflow,
) as worker:
# Run workflow
result = await client.execute_workflow(
SearchAttributeCodecParentWorkflow.run,
"Temporal",
id=f"encryption-workflow-id",
task_queue=worker.task_queue,
search_attributes=TypedSearchAttributes(
[
SearchAttributePair(
SearchAttributeWorkflow.text_attribute, "test_text"
)
]
),
)