Skip to content

Commit ea17e88

Browse files
committed
Static method generation
1 parent 41f51c4 commit ea17e88

File tree

1 file changed

+47
-35
lines changed

1 file changed

+47
-35
lines changed

temporalio/worker/_command_aware_visitor.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,23 @@ class CommandInfo:
4848
)
4949

5050

51+
def _create_override_method(
52+
parent_method: Any, command_type: CommandType.ValueType
53+
) -> Any:
54+
"""Create an override method that sets command context."""
55+
56+
async def override_method(self: Any, fs: VisitorFunctions, o: Any) -> None:
57+
with current_command(command_type, o.seq):
58+
await parent_method(self, fs, o)
59+
60+
return override_method
61+
62+
5163
class CommandAwarePayloadVisitor(PayloadVisitor):
5264
"""Payload visitor that sets command context during traversal.
5365
54-
Overridden methods are created for all workflow commands and activation jobs that have a 'seq'
55-
field.
66+
Override methods are created at class definition time for all workflow
67+
commands and activation jobs that have a 'seq' field.
5668
"""
5769

5870
_COMMAND_TYPE_MAP: dict[type[Any], Optional[CommandType.ValueType]] = {
@@ -80,41 +92,37 @@ class CommandAwarePayloadVisitor(PayloadVisitor):
8092
FireTimer: CommandType.COMMAND_TYPE_START_TIMER,
8193
}
8294

83-
def __init__(self, **kwargs: Any) -> None:
84-
"""Initialize the command-aware payload visitor."""
85-
super().__init__(**kwargs)
86-
self._create_override_methods()
87-
88-
def _create_override_methods(self) -> None:
89-
"""Dynamically create override methods for all protos with seq fields."""
90-
for proto_class in _get_workflow_command_protos_with_seq():
91-
if command_type := self._COMMAND_TYPE_MAP[proto_class]:
92-
self._add_override(
93-
proto_class, "coresdk_workflow_commands", command_type
94-
)
95-
for proto_class in _get_workflow_activation_job_protos_with_seq():
96-
if command_type := self._COMMAND_TYPE_MAP[proto_class]:
97-
self._add_override(
98-
proto_class, "coresdk_workflow_activation", command_type
99-
)
100-
101-
def _add_override(
102-
self, proto_class: Type[Any], module: str, command_type: CommandType.ValueType
103-
) -> None:
104-
"""Add an override method that sets command context."""
105-
method_name = f"_visit_{module}_{proto_class.__name__}"
106-
parent_method = getattr(PayloadVisitor, method_name, None)
107-
108-
if not parent_method:
109-
# No visitor method means no payload fields to visit
110-
return
11195

112-
async def override_method(fs: VisitorFunctions, o: Any) -> None:
113-
with current_command(command_type, o.seq):
114-
assert parent_method
115-
await parent_method(self, fs, o)
96+
# Add override methods to CommandAwarePayloadVisitor at class definition time
97+
def _add_class_overrides() -> None:
98+
"""Add override methods to CommandAwarePayloadVisitor class."""
99+
# Process workflow commands
100+
for proto_class in _get_workflow_command_protos_with_seq():
101+
if command_type := CommandAwarePayloadVisitor._COMMAND_TYPE_MAP.get(
102+
proto_class
103+
):
104+
method_name = f"_visit_coresdk_workflow_commands_{proto_class.__name__}"
105+
parent_method = getattr(PayloadVisitor, method_name, None)
106+
if parent_method:
107+
setattr(
108+
CommandAwarePayloadVisitor,
109+
method_name,
110+
_create_override_method(parent_method, command_type),
111+
)
116112

117-
setattr(self, method_name, override_method)
113+
# Process activation jobs
114+
for proto_class in _get_workflow_activation_job_protos_with_seq():
115+
if command_type := CommandAwarePayloadVisitor._COMMAND_TYPE_MAP.get(
116+
proto_class
117+
):
118+
method_name = f"_visit_coresdk_workflow_activation_{proto_class.__name__}"
119+
parent_method = getattr(PayloadVisitor, method_name, None)
120+
if parent_method:
121+
setattr(
122+
CommandAwarePayloadVisitor,
123+
method_name,
124+
_create_override_method(parent_method, command_type),
125+
)
118126

119127

120128
def _get_workflow_command_protos_with_seq() -> Iterator[Type[Any]]:
@@ -144,3 +152,7 @@ def current_command(
144152
finally:
145153
if token:
146154
current_command_info.reset(token)
155+
156+
157+
# Create all override methods on the class when the module is imported
158+
_add_class_overrides()

0 commit comments

Comments
 (0)