@@ -48,11 +48,23 @@ class CommandInfo:
4848)
4949
5050
51+ def _create_override_method (
52+ parent_method : Any , command_type : CommandType .ValueType
53+ ) -> Any :
54+ """Create an override method that sets command context."""
55+
56+ async def override_method (self : Any , fs : VisitorFunctions , o : Any ) -> None :
57+ with current_command (command_type , o .seq ):
58+ await parent_method (self , fs , o )
59+
60+ return override_method
61+
62+
5163class CommandAwarePayloadVisitor (PayloadVisitor ):
5264 """Payload visitor that sets command context during traversal.
5365
54- Overridden methods are created for all workflow commands and activation jobs that have a 'seq'
55- field.
66+ Override methods are created at class definition time for all workflow
67+ commands and activation jobs that have a 'seq' field.
5668 """
5769
5870 _COMMAND_TYPE_MAP : dict [type [Any ], Optional [CommandType .ValueType ]] = {
@@ -80,41 +92,37 @@ class CommandAwarePayloadVisitor(PayloadVisitor):
8092 FireTimer : CommandType .COMMAND_TYPE_START_TIMER ,
8193 }
8294
83- def __init__ (self , ** kwargs : Any ) -> None :
84- """Initialize the command-aware payload visitor."""
85- super ().__init__ (** kwargs )
86- self ._create_override_methods ()
87-
88- def _create_override_methods (self ) -> None :
89- """Dynamically create override methods for all protos with seq fields."""
90- for proto_class in _get_workflow_command_protos_with_seq ():
91- if command_type := self ._COMMAND_TYPE_MAP [proto_class ]:
92- self ._add_override (
93- proto_class , "coresdk_workflow_commands" , command_type
94- )
95- for proto_class in _get_workflow_activation_job_protos_with_seq ():
96- if command_type := self ._COMMAND_TYPE_MAP [proto_class ]:
97- self ._add_override (
98- proto_class , "coresdk_workflow_activation" , command_type
99- )
100-
101- def _add_override (
102- self , proto_class : Type [Any ], module : str , command_type : CommandType .ValueType
103- ) -> None :
104- """Add an override method that sets command context."""
105- method_name = f"_visit_{ module } _{ proto_class .__name__ } "
106- parent_method = getattr (PayloadVisitor , method_name , None )
107-
108- if not parent_method :
109- # No visitor method means no payload fields to visit
110- return
11195
112- async def override_method (fs : VisitorFunctions , o : Any ) -> None :
113- with current_command (command_type , o .seq ):
114- assert parent_method
115- await parent_method (self , fs , o )
96+ # Add override methods to CommandAwarePayloadVisitor at class definition time
97+ def _add_class_overrides () -> None :
98+ """Add override methods to CommandAwarePayloadVisitor class."""
99+ # Process workflow commands
100+ for proto_class in _get_workflow_command_protos_with_seq ():
101+ if command_type := CommandAwarePayloadVisitor ._COMMAND_TYPE_MAP .get (
102+ proto_class
103+ ):
104+ method_name = f"_visit_coresdk_workflow_commands_{ proto_class .__name__ } "
105+ parent_method = getattr (PayloadVisitor , method_name , None )
106+ if parent_method :
107+ setattr (
108+ CommandAwarePayloadVisitor ,
109+ method_name ,
110+ _create_override_method (parent_method , command_type ),
111+ )
116112
117- setattr (self , method_name , override_method )
113+ # Process activation jobs
114+ for proto_class in _get_workflow_activation_job_protos_with_seq ():
115+ if command_type := CommandAwarePayloadVisitor ._COMMAND_TYPE_MAP .get (
116+ proto_class
117+ ):
118+ method_name = f"_visit_coresdk_workflow_activation_{ proto_class .__name__ } "
119+ parent_method = getattr (PayloadVisitor , method_name , None )
120+ if parent_method :
121+ setattr (
122+ CommandAwarePayloadVisitor ,
123+ method_name ,
124+ _create_override_method (parent_method , command_type ),
125+ )
118126
119127
120128def _get_workflow_command_protos_with_seq () -> Iterator [Type [Any ]]:
@@ -144,3 +152,7 @@ def current_command(
144152 finally :
145153 if token :
146154 current_command_info .reset (token )
155+
156+
157+ # Create all override methods on the class when the module is imported
158+ _add_class_overrides ()
0 commit comments