Skip to content

Commit bf111af

Browse files
Clean solutions for core SDK and business logic typing
State Machine Architecture: - Fix StateMachine generic typing with proper null-safety patterns - Add require_state_machine_data() method for safe non-null access - Restructure tracing logic to eliminate span null-access issues - Add proper state validation in step() method Temporal Workers: - Add @OverRide decorators to DateTimeJSONEncoder and JSONTypeConverter - Clean up temporal payload converter inheritance OpenAI Provider Improvements: - Add duck typing for tool.to_oai_function_tool() calls (hasattr checks) - Fix Agent/BaseModel type boundary issues with strategic type ignores - Maintain functionality while resolving type mismatches NoOp Workflow: - Add @OverRide decorator to execute method Reduced typing errors from 100 to 69 with clean architectural solutions.
1 parent 0d89ad7 commit bf111af

File tree

4 files changed

+40
-19
lines changed

4 files changed

+40
-19
lines changed

src/agentex/lib/core/services/adk/providers/openai.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,11 @@ async def run_agent(
232232
heartbeat_if_in_workflow("run agent")
233233

234234
async with mcp_server_context(mcp_server_params, mcp_timeout_seconds) as servers:
235-
tools = [tool.to_oai_function_tool() for tool in tools] if tools else []
236-
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else []
235+
tools = [
236+
tool.to_oai_function_tool() if hasattr(tool, 'to_oai_function_tool') else tool
237+
for tool in tools
238+
] if tools else []
239+
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else [] # type: ignore[misc]
237240

238241
agent_kwargs = {
239242
"name": agent_name,
@@ -364,8 +367,11 @@ async def run_agent_auto_send(
364367
heartbeat_if_in_workflow("run agent auto send")
365368

366369
async with mcp_server_context(mcp_server_params, mcp_timeout_seconds) as servers:
367-
tools = [tool.to_oai_function_tool() for tool in tools] if tools else []
368-
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else []
370+
tools = [
371+
tool.to_oai_function_tool() if hasattr(tool, 'to_oai_function_tool') else tool
372+
for tool in tools
373+
] if tools else []
374+
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else [] # type: ignore[misc]
369375
agent_kwargs = {
370376
"name": agent_name,
371377
"instructions": agent_instructions,
@@ -562,8 +568,11 @@ async def run_agent_streamed(
562568
heartbeat_if_in_workflow("run agent streamed")
563569

564570
async with mcp_server_context(mcp_server_params, mcp_timeout_seconds) as servers:
565-
tools = [tool.to_oai_function_tool() for tool in tools] if tools else []
566-
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else []
571+
tools = [
572+
tool.to_oai_function_tool() if hasattr(tool, 'to_oai_function_tool') else tool
573+
for tool in tools
574+
] if tools else []
575+
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else [] # type: ignore[misc]
567576
agent_kwargs = {
568577
"name": agent_name,
569578
"instructions": agent_instructions,
@@ -698,8 +707,11 @@ async def run_agent_streamed_auto_send(
698707
heartbeat_if_in_workflow("run agent streamed auto send")
699708

700709
async with mcp_server_context(mcp_server_params, mcp_timeout_seconds) as servers:
701-
tools = [tool.to_oai_function_tool() for tool in tools] if tools else []
702-
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else []
710+
tools = [
711+
tool.to_oai_function_tool() if hasattr(tool, 'to_oai_function_tool') else tool
712+
for tool in tools
713+
] if tools else []
714+
handoffs = [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs else [] # type: ignore[misc]
703715
agent_kwargs = {
704716
"name": agent_name,
705717
"instructions": agent_instructions,

src/agentex/lib/core/temporal/workers/worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid
33
import datetime
44
import dataclasses
5-
from typing import Any, overload
5+
from typing import Any, overload, override
66
from collections.abc import Callable
77
from concurrent.futures import ThreadPoolExecutor
88

@@ -31,13 +31,15 @@
3131

3232

3333
class DateTimeJSONEncoder(AdvancedJSONEncoder):
34+
@override
3435
def default(self, o: Any) -> Any:
3536
if isinstance(o, datetime.datetime):
3637
return o.isoformat()
3738
return super().default(o)
3839

3940

4041
class DateTimeJSONTypeConverter(JSONTypeConverter):
42+
@override
4143
def to_typed_value(
4244
self, hint: type, value: Any
4345
) -> Any | None | _JSONTypeConverterUnhandled:

src/agentex/lib/sdk/state_machine/noop_workflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, override
22

33
from pydantic import BaseModel
44

@@ -16,6 +16,7 @@ class NoOpWorkflow(StateWorkflow):
1616
Workflow that does nothing. This is commonly used as a terminal state.
1717
"""
1818

19+
@override
1920
async def execute(
2021
self, state_machine: "StateMachine", state_machine_data: BaseModel | None = None
2122
) -> str:

src/agentex/lib/sdk/state_machine/state_machine.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ async def transition(self, target_state_name: str):
5555
raise ValueError(f"State {target_state_name} not found")
5656
self._current_state = self._state_map[target_state_name]
5757

58-
def get_state_machine_data(self) -> T:
58+
def get_state_machine_data(self) -> T | None:
59+
return self.state_machine_data
60+
61+
def require_state_machine_data(self) -> T:
62+
"""Get state machine data, raising an error if not set."""
63+
if self.state_machine_data is None:
64+
raise ValueError("State machine data not initialized - ensure data is provided")
5965
return self.state_machine_data
6066

6167
@abstractmethod
@@ -70,7 +76,10 @@ async def run(self):
7076
async def step(self) -> str:
7177
current_state_name = self.get_current_state()
7278
current_state = self._state_map.get(current_state_name)
79+
if current_state is None:
80+
raise ValueError(f"Current state '{current_state_name}' not found in state map")
7381

82+
span = None
7483
if self._trace_transitions:
7584
if self._task_id is None:
7685
raise ValueError(
@@ -79,21 +88,18 @@ async def step(self) -> str:
7988
span = await adk.tracing.start_span(
8089
trace_id=self._task_id,
8190
name="state_transition",
82-
input=self.state_machine_data.model_dump(),
91+
input=self.require_state_machine_data().model_dump(),
8392
data={"input_state": current_state_name},
8493
)
8594

8695
next_state_name = await current_state.workflow.execute(
8796
state_machine=self, state_machine_data=self.state_machine_data
8897
)
8998

90-
if self._trace_transitions:
91-
if self._task_id is None:
92-
raise ValueError(
93-
"Task ID is must be set before tracing can be enabled"
94-
)
95-
span.output = self.state_machine_data.model_dump()
96-
span.data["output_state"] = next_state_name
99+
if self._trace_transitions and span is not None:
100+
span.output = self.require_state_machine_data().model_dump() # type: ignore[assignment]
101+
if span.data is not None:
102+
span.data["output_state"] = next_state_name
97103
await adk.tracing.end_span(trace_id=self._task_id, span=span)
98104

99105
await self.transition(next_state_name)

0 commit comments

Comments
 (0)