Skip to content

Commit 00ef528

Browse files
committed
Linting
1 parent 7ffa586 commit 00ef528

File tree

10 files changed

+134
-45
lines changed

10 files changed

+134
-45
lines changed

scripts/gen_protos.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def generate_protos(output_dir: Path):
201201
/ v,
202202
)
203203

204+
204205
if __name__ == "__main__":
205206
check_proto_toolchain_versions()
206207
print("Generating protos...", file=sys.stderr)

scripts/gen_visitors.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
base_dir = Path(__file__).parent.parent
1515

16+
1617
def gen_workflow_activation_payload_visitor_code() -> str:
1718
"""
1819
Generate Python source code that, given a function f(Payload) -> Payload,
@@ -23,29 +24,44 @@ def gen_workflow_activation_payload_visitor_code() -> str:
2324
for repeated fields and map entries, and a convenience entrypoint
2425
function `visit_workflow_activation_payloads`.
2526
"""
27+
2628
def name_for(desc: Descriptor) -> str:
2729
# Use fully-qualified name to avoid collisions; replace dots with underscores
28-
return desc.full_name.replace('.', '_')
29-
30-
def emit_loop(lines: list[str], field_name: str, iter_expr: str, var_name: str, child_method: str) -> None:
30+
return desc.full_name.replace(".", "_")
31+
32+
def emit_loop(
33+
lines: list[str],
34+
field_name: str,
35+
iter_expr: str,
36+
var_name: str,
37+
child_method: str,
38+
) -> None:
3139
# Helper to emit a for-loop over a collection with optional headers guard
3240
if field_name == "headers":
3341
lines.append(" if not self.skip_headers:")
3442
lines.append(f" for {var_name} in {iter_expr}:")
35-
lines.append(f" await self.visit_{child_method}(f, {var_name})")
43+
lines.append(
44+
f" await self.visit_{child_method}(f, {var_name})"
45+
)
3646
else:
3747
lines.append(f" for {var_name} in {iter_expr}:")
3848
lines.append(f" await self.visit_{child_method}(f, {var_name})")
3949

40-
def emit_singular(lines: list[str], field_name: str, access_expr: str, child_method: str) -> None:
50+
def emit_singular(
51+
lines: list[str], field_name: str, access_expr: str, child_method: str
52+
) -> None:
4153
# Helper to emit a singular field visit with presence check and optional headers guard
4254
if field_name == "headers":
4355
lines.append(" if not self.skip_headers:")
4456
lines.append(f" if o.HasField('{field_name}'):")
45-
lines.append(f" await self.visit_{child_method}(f, {access_expr})")
57+
lines.append(
58+
f" await self.visit_{child_method}(f, {access_expr})"
59+
)
4660
else:
4761
lines.append(f" if o.HasField('{field_name}'):")
48-
lines.append(f" await self.visit_{child_method}(f, {access_expr})")
62+
lines.append(
63+
f" await self.visit_{child_method}(f, {access_expr})"
64+
)
4965

5066
# Track which message descriptors have visitor methods generated
5167
generated: dict[str, bool] = {}
@@ -83,36 +99,65 @@ def walk(desc: Descriptor) -> bool:
8399

84100
# Repeated fields (including maps which are represented as repeated messages)
85101
if field.label == FieldDescriptor.LABEL_REPEATED:
86-
if field.message_type is not None and field.message_type.GetOptions().map_entry:
102+
if (
103+
field.message_type is not None
104+
and field.message_type.GetOptions().map_entry
105+
):
87106
entry_desc = field.message_type
88107
key_fd = entry_desc.fields_by_name.get("key")
89108
val_fd = entry_desc.fields_by_name.get("value")
90109

91-
if val_fd is not None and val_fd.type == FieldDescriptor.TYPE_MESSAGE:
110+
if (
111+
val_fd is not None
112+
and val_fd.type == FieldDescriptor.TYPE_MESSAGE
113+
):
92114
child_desc = val_fd.message_type
93115
child_needed = walk(child_desc)
94116
needed |= child_needed
95117
if child_needed:
96-
emit_loop(lines, field.name, f"o.{field.name}.values()", "v", name_for(child_desc))
97-
98-
if key_fd is not None and key_fd.type == FieldDescriptor.TYPE_MESSAGE:
118+
emit_loop(
119+
lines,
120+
field.name,
121+
f"o.{field.name}.values()",
122+
"v",
123+
name_for(child_desc),
124+
)
125+
126+
if (
127+
key_fd is not None
128+
and key_fd.type == FieldDescriptor.TYPE_MESSAGE
129+
):
99130
key_desc = key_fd.message_type
100131
child_needed = walk(key_desc)
101132
needed |= child_needed
102133
if child_needed:
103-
emit_loop(lines, field.name, f"o.{field.name}.keys()", "k", name_for(key_desc))
134+
emit_loop(
135+
lines,
136+
field.name,
137+
f"o.{field.name}.keys()",
138+
"k",
139+
name_for(key_desc),
140+
)
104141
else:
105142
child_desc = field.message_type
106143
child_needed = walk(child_desc)
107144
needed |= child_needed
108145
if child_needed:
109-
emit_loop(lines, field.name, f"o.{field.name}", "v", name_for(child_desc))
146+
emit_loop(
147+
lines,
148+
field.name,
149+
f"o.{field.name}",
150+
"v",
151+
name_for(child_desc),
152+
)
110153
else:
111154
child_desc = field.message_type
112155
child_needed = walk(child_desc)
113156
needed |= child_needed
114157
if child_needed:
115-
emit_singular(lines, field.name, f"o.{field.name}", name_for(child_desc))
158+
emit_singular(
159+
lines, field.name, f"o.{field.name}", name_for(child_desc)
160+
)
116161

117162
generated[key] = needed
118163
in_progress.discard(key)
@@ -134,7 +179,7 @@ def walk(desc: Descriptor) -> bool:
134179
walk(r)
135180

136181
header = (
137-
"from typing import Awaitable, Callable, Any\n\n"
182+
"from typing import Any, Awaitable, Callable\n\n"
138183
"from temporalio.api.common.v1.message_pb2 import Payload\n\n\n"
139184
"class PayloadVisitor:\n"
140185
" def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n"
@@ -156,7 +201,7 @@ def write_generated_visitors_into_visitor_generated_py() -> None:
156201
code = gen_workflow_activation_payload_visitor_code()
157202
out_path.write_text(code)
158203

204+
159205
if __name__ == "__main__":
160206
print("Generating temporalio/bridge/visitor_generated.py...", file=sys.stderr)
161207
write_generated_visitors_into_visitor_generated_py()
162-

temporalio/bridge/visitor.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010

1111
class PayloadVisitor:
12-
def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):
12+
def __init__(
13+
self, *, skip_search_attributes: bool = False, skip_headers: bool = False
14+
):
1315
self.skip_search_attributes = skip_search_attributes
1416
self.skip_headers = skip_headers
1517

@@ -31,8 +33,10 @@ async def visit_payloads(
3133
for o in root:
3234
await self.visit_payloads(f, o)
3335
elif isinstance(root, Message):
34-
await self.visit_message(f, root,)
35-
36+
await self.visit_message(
37+
f,
38+
root,
39+
)
3640

3741
async def visit_message(
3842
self, f: Callable[[Payload], Awaitable[Payload]], root: Message
@@ -44,7 +48,10 @@ async def visit_message(
4448
# Repeated fields (including maps which are represented as repeated messages)
4549
if field.label == FieldDescriptor.LABEL_REPEATED:
4650
value = getattr(root, field.name)
47-
if field.message_type is not None and field.message_type.GetOptions().map_entry:
51+
if (
52+
field.message_type is not None
53+
and field.message_type.GetOptions().map_entry
54+
):
4855
for k, v in value.items():
4956
await self.visit_payloads(f, k)
5057
await self.visit_payloads(f, v)
@@ -53,6 +60,8 @@ async def visit_message(
5360
await self.visit_payloads(f, item)
5461
else:
5562
# Only descend into singular message fields if present
56-
if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name):
63+
if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(
64+
field.name
65+
):
5766
value = getattr(root, field.name)
5867
await self.visit_payloads(f, value)

temporalio/bridge/visitor_generated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Awaitable, Callable, Any
1+
from typing import Any, Awaitable, Callable
22

33
from temporalio.api.common.v1.message_pb2 import Payload
44

temporalio/bridge/worker.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,18 +372,25 @@ async def decode_activation(
372372
decode_headers: bool,
373373
) -> None:
374374
"""Decode the given activation with the codec."""
375+
375376
async def visitor(payload: Payload) -> Payload:
376377
return (await codec.decode([payload]))[0]
377378

378-
await PayloadVisitor(skip_search_attributes=True, skip_headers=not decode_headers).visit_message(visitor, act)
379+
await PayloadVisitor(
380+
skip_search_attributes=True, skip_headers=not decode_headers
381+
).visit_message(visitor, act)
382+
379383

380384
async def encode_completion(
381385
comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
382386
codec: temporalio.converter.PayloadCodec,
383387
encode_headers: bool,
384388
) -> None:
385389
"""Recursively encode the given completion with the codec."""
390+
386391
async def visitor(payload: Payload) -> Payload:
387392
return (await codec.encode([payload]))[0]
388393

389-
await PayloadVisitor(skip_search_attributes=True, skip_headers=not encode_headers).visit_message(visitor, comp)
394+
await PayloadVisitor(
395+
skip_search_attributes=True, skip_headers=not encode_headers
396+
).visit_message(visitor, comp)

temporalio/worker/_workflow_instance.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class WorkflowInstanceDetails:
146146
worker_level_failure_exception_types: Sequence[Type[BaseException]]
147147
last_completion_result: Payloads
148148

149+
149150
class WorkflowInstance(ABC):
150151
"""Instance of a workflow that can handle activations."""
151152

@@ -1689,16 +1690,28 @@ def workflow_set_current_details(self, details: str):
16891690
self._assert_not_read_only("set current details")
16901691
self._current_details = details
16911692

1692-
def workflow_last_completion_result(self, type_hint: Optional[Type]) -> Optional[Any]:
1693-
print("workflow_last_completion_result: ", self._last_completion_result, type(self._last_completion_result), "payload length:", len(self._last_completion_result.payloads))
1693+
def workflow_last_completion_result(
1694+
self, type_hint: Optional[Type]
1695+
) -> Optional[Any]:
1696+
print(
1697+
"workflow_last_completion_result: ",
1698+
self._last_completion_result,
1699+
type(self._last_completion_result),
1700+
"payload length:",
1701+
len(self._last_completion_result.payloads),
1702+
)
16941703
if len(self._last_completion_result.payloads) == 0:
16951704
return None
16961705
elif len(self._last_completion_result.payloads) > 1:
1697-
warnings.warn(f"Expected single last completion result, got {len(self._last_completion_result.payloads)}")
1706+
warnings.warn(
1707+
f"Expected single last completion result, got {len(self._last_completion_result.payloads)}"
1708+
)
16981709
return None
16991710

17001711
print("Payload:", self._last_completion_result.payloads[0])
1701-
return self._payload_converter.from_payload(self._last_completion_result.payloads[0], type_hint)
1712+
return self._payload_converter.from_payload(
1713+
self._last_completion_result.payloads[0], type_hint
1714+
)
17021715

17031716
#### Calls from outbound impl ####
17041717
# These are in alphabetical order and all start with "_outbound_".

temporalio/workflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,9 @@ def workflow_get_current_details(self) -> str: ...
898898
def workflow_set_current_details(self, details: str): ...
899899

900900
@abstractmethod
901-
def workflow_last_completion_result(self, type_hint: Optional[Type]) -> Optional[Any]: ...
901+
def workflow_last_completion_result(
902+
self, type_hint: Optional[Type]
903+
) -> Optional[Any]: ...
902904

903905

904906
_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(

tests/test_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1541,4 +1541,3 @@ async def test_schedule_last_completion_result(
15411541

15421542
await handle.delete()
15431543
assert False
1544-

tests/worker/test_visitor.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ async def test_workflow_activation_completion():
4444
schedule_to_close_timeout=Duration(seconds=5),
4545
priority=Priority(),
4646
),
47-
user_metadata=UserMetadata(
48-
summary=Payload(data=b"Summary")
49-
),
47+
user_metadata=UserMetadata(summary=Payload(data=b"Summary")),
5048
)
5149
],
5250
),
@@ -78,9 +76,7 @@ async def test_workflow_activation():
7876
Payload(data=b"repeated1"),
7977
Payload(data=b"repeated2"),
8078
],
81-
headers={
82-
"header":Payload(data=b"map")
83-
},
79+
headers={"header": Payload(data=b"map")},
8480
last_completion_result=Payloads(
8581
payloads=[
8682
Payload(data=b"obj1"),
@@ -89,9 +85,9 @@ async def test_workflow_activation():
8985
),
9086
search_attributes=SearchAttributes(
9187
indexed_fields={
92-
"sakey":Payload(data=b"saobj"),
88+
"sakey": Payload(data=b"saobj"),
9389
}
94-
)
90+
),
9591
),
9692
)
9793
]
@@ -110,13 +106,29 @@ async def visitor(payload: Payload) -> Payload:
110106
assert act.jobs[0].initialize_workflow.arguments[0].metadata["visited"]
111107
assert act.jobs[0].initialize_workflow.arguments[1].metadata["visited"]
112108
assert act.jobs[0].initialize_workflow.headers["header"].metadata["visited"]
113-
assert act.jobs[0].initialize_workflow.last_completion_result.payloads[0].metadata["visited"]
114-
assert act.jobs[0].initialize_workflow.last_completion_result.payloads[1].metadata["visited"]
115-
assert act.jobs[0].initialize_workflow.search_attributes.indexed_fields["sakey"].metadata["visited"]
109+
assert (
110+
act.jobs[0]
111+
.initialize_workflow.last_completion_result.payloads[0]
112+
.metadata["visited"]
113+
)
114+
assert (
115+
act.jobs[0]
116+
.initialize_workflow.last_completion_result.payloads[1]
117+
.metadata["visited"]
118+
)
119+
assert (
120+
act.jobs[0]
121+
.initialize_workflow.search_attributes.indexed_fields["sakey"]
122+
.metadata["visited"]
123+
)
116124

117125
act = original.__deepcopy__()
118126
await PayloadVisitor(skip_search_attributes=True).visit(visitor, act)
119-
assert not act.jobs[0].initialize_workflow.search_attributes.indexed_fields["sakey"].metadata["visited"]
127+
assert (
128+
not act.jobs[0]
129+
.initialize_workflow.search_attributes.indexed_fields["sakey"]
130+
.metadata["visited"]
131+
)
120132

121133
act = original.__deepcopy__()
122134
await PayloadVisitor(skip_headers=True).visit(visitor, act)
@@ -198,6 +210,7 @@ async def visitor(payload: Payload) -> Payload:
198210
ur = cmds[4].update_response
199211
assert ur.completed.data == b"visited:e1"
200212

213+
201214
async def test_code_gen():
202215
# Smoke test the generated visitor on a simple activation containing payloads
203216
act = WorkflowActivation(
@@ -221,4 +234,4 @@ async def _f(p: Payload) -> Payload:
221234
init = act.jobs[0].initialize_workflow
222235
assert init.arguments[0].data == b"v:x1"
223236
assert init.arguments[1].data == b"v:x2"
224-
assert init.headers["h"].data == b"v:x3"
237+
assert init.headers["h"].data == b"v:x3"

tests/worker/test_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8329,4 +8329,4 @@ async def test_workflow_headers_with_codec(
83298329
assert headers["foo"].data == b"bar"
83308330
else:
83318331
assert headers["foo"].data != b"bar"
8332-
assert False
8332+
assert False

0 commit comments

Comments
 (0)