Skip to content

Commit 2abe719

Browse files
committed
Add test
1 parent 8af3aa1 commit 2abe719

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

tests/worker/test_workflow.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8369,3 +8369,94 @@ async def test_previous_run_failure(client: Client):
83698369
)
83708370
result = await handle.result()
83718371
assert result == "Done"
8372+
8373+
class EncryptionCodec(PayloadCodec):
8374+
def __init__(self, key_id: str = "test-key-id", key: bytes = b"test-key-test-key-test-key-test!") -> None:
8375+
super().__init__()
8376+
self.key_id = key_id
8377+
8378+
async def encode(self, payloads: Iterable[Payload]) -> List[Payload]:
8379+
# We blindly encode all payloads with the key and set the metadata
8380+
# saying which key we used
8381+
return [
8382+
Payload(
8383+
metadata={
8384+
"encoding": b"binary/encrypted",
8385+
"encryption-key-id": self.key_id.encode(),
8386+
},
8387+
data=self.encrypt(p.SerializeToString()),
8388+
)
8389+
for p in payloads
8390+
]
8391+
8392+
async def decode(self, payloads: Iterable[Payload]) -> List[Payload]:
8393+
ret: List[Payload] = []
8394+
for p in payloads:
8395+
# Ignore ones w/out our expected encoding
8396+
if p.metadata.get("encoding", b"").decode() != "binary/encrypted":
8397+
ret.append(p)
8398+
continue
8399+
# Confirm our key ID is the same
8400+
key_id = p.metadata.get("encryption-key-id", b"").decode()
8401+
if key_id != self.key_id:
8402+
raise ValueError(f"Unrecognized key ID {key_id}. Current key ID is {self.key_id}.")
8403+
# Decrypt and append
8404+
ret.append(Payload.FromString(self.decrypt(p.data)))
8405+
return ret
8406+
8407+
def encrypt(self, data: bytes) -> bytes:
8408+
nonce = os.urandom(12)
8409+
return data
8410+
8411+
def decrypt(self, data: bytes) -> bytes:
8412+
return data
8413+
8414+
8415+
@workflow.defn
8416+
class SearchAttributeCodecParentWorkflow:
8417+
text_attribute = SearchAttributeKey.for_text(f"text_sa")
8418+
8419+
@workflow.run
8420+
async def run(self, name: str) -> str:
8421+
print(
8422+
await workflow.execute_child_workflow(
8423+
workflow=SearchAttributeCodecChildWorkflow.run,
8424+
arg=name,
8425+
id=f"child-{name}",
8426+
search_attributes=workflow.info().typed_search_attributes,
8427+
)
8428+
)
8429+
return f"Hello, {name}"
8430+
8431+
8432+
@workflow.defn
8433+
class SearchAttributeCodecChildWorkflow:
8434+
@workflow.run
8435+
async def run(self, name: str) -> str:
8436+
return f"Hello from child, {name}"
8437+
8438+
async def test_search_attribute_codec(client: Client):
8439+
await ensure_search_attributes_present(
8440+
client,
8441+
SearchAttributeCodecParentWorkflow.text_attribute,
8442+
)
8443+
8444+
config = client.config()
8445+
config["data_converter"] = dataclasses.replace(temporalio.converter.default(), payload_codec=EncryptionCodec())
8446+
client = Client(**config)
8447+
8448+
# Run a worker for the workflow
8449+
async with new_worker(
8450+
client,
8451+
SearchAttributeCodecParentWorkflow, SearchAttributeCodecChildWorkflow,
8452+
) as worker:
8453+
# Run workflow
8454+
result = await client.execute_workflow(
8455+
SearchAttributeCodecParentWorkflow.run,
8456+
"Temporal",
8457+
id=f"encryption-workflow-id",
8458+
task_queue=worker.task_queue,
8459+
search_attributes=TypedSearchAttributes(
8460+
[SearchAttributePair(SearchAttributeCodecParentWorkflow.text_attribute, "test_text")]
8461+
),
8462+
)

0 commit comments

Comments
 (0)