Skip to content

Commit 164b5df

Browse files
committed
Fix search attribute skipping on protos which don't use the SearchAttributes message type
1 parent ed1f6ca commit 164b5df

File tree

3 files changed

+100
-7
lines changed

3 files changed

+100
-7
lines changed

scripts/gen_payload_visitor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def emit_loop(
3232
if not self.skip_headers:
3333
for v in {iter_expr}:
3434
await self._visit_{child_method}(fs, v)"""
35+
elif field_name == "search_attributes":
36+
return f"""\
37+
if not self.skip_search_attributes:
38+
for v in {iter_expr}:
39+
await self._visit_{child_method}(fs, v)"""
3540
else:
3641
return f"""\
3742
for v in {iter_expr}:
@@ -197,7 +202,7 @@ def walk(self, desc: Descriptor) -> bool:
197202
# Process regular fields first
198203
for field in regular_fields:
199204
# Repeated fields (including maps which are represented as repeated messages)
200-
if field.label == FieldDescriptor.LABEL_REPEATED:
205+
if field.is_repeated:
201206
if (
202207
field.message_type is not None
203208
and field.message_type.GetOptions().map_entry

temporalio/bridge/_visitor.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(
320320
if not self.skip_headers:
321321
for v in o.headers.values():
322322
await self._visit_temporal_api_common_v1_Payload(fs, v)
323-
for v in o.search_attributes.values():
324-
await self._visit_temporal_api_common_v1_Payload(fs, v)
323+
if not self.skip_search_attributes:
324+
for v in o.search_attributes.values():
325+
await self._visit_temporal_api_common_v1_Payload(fs, v)
325326

326327
async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o):
327328
await self._visit_payload_container(fs, o.input)
@@ -330,8 +331,9 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs,
330331
await self._visit_temporal_api_common_v1_Payload(fs, v)
331332
for v in o.memo.values():
332333
await self._visit_temporal_api_common_v1_Payload(fs, v)
333-
for v in o.search_attributes.values():
334-
await self._visit_temporal_api_common_v1_Payload(fs, v)
334+
if not self.skip_search_attributes:
335+
for v in o.search_attributes.values():
336+
await self._visit_temporal_api_common_v1_Payload(fs, v)
335337

336338
async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(
337339
self, fs, o
@@ -350,8 +352,9 @@ async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o):
350352
async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(
351353
self, fs, o
352354
):
353-
for v in o.search_attributes.values():
354-
await self._visit_temporal_api_common_v1_Payload(fs, v)
355+
if not self.skip_search_attributes:
356+
for v in o.search_attributes.values():
357+
await self._visit_temporal_api_common_v1_Payload(fs, v)
355358

356359
async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o):
357360
if o.HasField("upserted_memo"):

tests/worker/test_workflow.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8369,3 +8369,88 @@ async def test_previous_run_failure(client: Client):
83698369
)
83708370
result = await handle.result()
83718371
assert result == "Done"
8372+
8373+
class EncryptionCodec(PayloadCodec):
8374+
def __init__(self, key_id: str = "test-key-id", key: bytes = b"test-key-test-key-test-key-test!") -> None:
8375+
super().__init__()
8376+
self.key_id = key_id
8377+
8378+
async def encode(self, payloads: Iterable[Payload]) -> List[Payload]:
8379+
# We blindly encode all payloads with the key and set the metadata
8380+
# saying which key we used
8381+
return [
8382+
Payload(
8383+
metadata={
8384+
"encoding": b"binary/encrypted",
8385+
"encryption-key-id": self.key_id.encode(),
8386+
},
8387+
data=self.encrypt(p.SerializeToString()),
8388+
)
8389+
for p in payloads
8390+
]
8391+
8392+
async def decode(self, payloads: Iterable[Payload]) -> List[Payload]:
8393+
ret: List[Payload] = []
8394+
for p in payloads:
8395+
# Ignore ones w/out our expected encoding
8396+
if p.metadata.get("encoding", b"").decode() != "binary/encrypted":
8397+
ret.append(p)
8398+
continue
8399+
# Confirm our key ID is the same
8400+
key_id = p.metadata.get("encryption-key-id", b"").decode()
8401+
if key_id != self.key_id:
8402+
raise ValueError(f"Unrecognized key ID {key_id}. Current key ID is {self.key_id}.")
8403+
# Decrypt and append
8404+
ret.append(Payload.FromString(self.decrypt(p.data)))
8405+
return ret
8406+
8407+
def encrypt(self, data: bytes) -> bytes:
8408+
nonce = os.urandom(12)
8409+
return data
8410+
8411+
def decrypt(self, data: bytes) -> bytes:
8412+
return data
8413+
8414+
8415+
@workflow.defn(name="Workflow")
8416+
class GreetingWorkflow:
8417+
@workflow.run
8418+
async def run(self, name: str) -> str:
8419+
print(
8420+
await workflow.execute_child_workflow(
8421+
workflow=ChildWorkflow.run,
8422+
arg=name,
8423+
id=f"child-{name}",
8424+
search_attributes=workflow.info().typed_search_attributes,
8425+
)
8426+
)
8427+
return f"Hello, {name}"
8428+
8429+
8430+
@workflow.defn(name="ChildWorkflow")
8431+
class ChildWorkflow:
8432+
@workflow.run
8433+
async def run(self, name: str) -> str:
8434+
return f"Hello from child, {name}"
8435+
8436+
async def test_search_attribute_codec(client: Client):
8437+
8438+
config = client.config()
8439+
config["data_converter"] = dataclasses.replace(temporalio.converter.default(), payload_codec=EncryptionCodec())
8440+
client = Client(**config)
8441+
# Run a worker for the workflow
8442+
async with Worker(
8443+
client,
8444+
task_queue="encryption-task-queue",
8445+
workflows=[GreetingWorkflow, ChildWorkflow],
8446+
):
8447+
# Run workflow
8448+
result = await client.execute_workflow(
8449+
GreetingWorkflow.run,
8450+
"Temporal",
8451+
id=f"encryption-workflow-id",
8452+
task_queue="encryption-task-queue",
8453+
search_attributes=TypedSearchAttributes(
8454+
[SearchAttributePair(SearchAttributeKey.for_keyword("show_name"), "test_show")]
8455+
),
8456+
)

0 commit comments

Comments
 (0)