Skip to content

Commit 411d635

Browse files
committed
Update changes
1 parent 7dc6231 commit 411d635

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

src/strands/agent/agent.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -289,18 +289,22 @@ def __init__(
289289
config_loader = AgentConfigLoader()
290290
configured_agent = config_loader.load_agent(agent_config)
291291

292+
# There is an odd mypy type discrepancy in this file for some reason
293+
# There is a mismatch between src.strands.* and strands.*
294+
# type/ignore annotations allow type checking to pass
295+
292296
# Override config values with any explicitly provided parameters
293297
if model is not None:
294298
configured_agent.model = (
295-
BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
299+
BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model # type: ignore
296300
)
297301
if messages is not None:
298302
configured_agent.messages = messages
299303
if tools is not None:
300304
# Need to reinitialize tool registry with new tools
301-
configured_agent.tool_registry = ToolRegistry()
305+
configured_agent.tool_registry = ToolRegistry() # type: ignore
302306
configured_agent.tool_registry.process_tools(tools)
303-
configured_agent.tool_registry.initialize_tools(load_tools_from_directory)
307+
configured_agent.tool_registry.initialize_tools(load_tools_from_directory) # type: ignore
304308
if system_prompt is not None:
305309
configured_agent.system_prompt = system_prompt
306310
if not isinstance(callback_handler, _DefaultCallbackHandlerSentinel):
@@ -309,7 +313,7 @@ def __init__(
309313
else:
310314
configured_agent.callback_handler = callback_handler
311315
if conversation_manager is not None:
312-
configured_agent.conversation_manager = conversation_manager
316+
configured_agent.conversation_manager = conversation_manager # type: ignore
313317
if trace_attributes is not None:
314318
configured_agent.trace_attributes = {}
315319
for k, v in trace_attributes.items():
@@ -325,17 +329,17 @@ def __init__(
325329
configured_agent.description = description
326330
if state is not None:
327331
if isinstance(state, dict):
328-
configured_agent.state = AgentState(state)
332+
configured_agent.state = AgentState(state) # type: ignore
329333
elif isinstance(state, AgentState):
330-
configured_agent.state = state
334+
configured_agent.state = state # type: ignore
331335
else:
332336
raise ValueError("state must be an AgentState object or a dict")
333337
if hooks is not None:
334338
for hook in hooks:
335-
configured_agent.hooks.add_hook(hook)
339+
configured_agent.hooks.add_hook(hook) # type: ignore
336340
if session_manager is not None:
337-
configured_agent._session_manager = session_manager
338-
configured_agent.hooks.add_hook(session_manager)
341+
configured_agent._session_manager = session_manager # type: ignore
342+
configured_agent.hooks.add_hook(session_manager) # type: ignore
339343

340344
# Override record_direct_tool_call and load_tools_from_directory only if explicitly provided
341345
if record_direct_tool_call is not None:
@@ -344,9 +348,9 @@ def __init__(
344348
configured_agent.load_tools_from_directory = load_tools_from_directory
345349
if load_tools_from_directory:
346350
if hasattr(configured_agent, "tool_watcher"):
347-
configured_agent.tool_watcher = ToolWatcher(tool_registry=configured_agent.tool_registry)
351+
configured_agent.tool_watcher = ToolWatcher(tool_registry=configured_agent.tool_registry) # type: ignore
348352
else:
349-
configured_agent.tool_watcher = ToolWatcher(tool_registry=configured_agent.tool_registry)
353+
configured_agent.tool_watcher = ToolWatcher(tool_registry=configured_agent.tool_registry) # type: ignore
350354

351355
# Copy all attributes from configured agent to self
352356
self.__dict__.update(configured_agent.__dict__)
@@ -455,10 +459,10 @@ def _load_config(self, config: Union[str, Path, Dict[str, Any]]) -> Dict[str, An
455459

456460
if suffix in [".yaml", ".yml"]:
457461
with open(config_path, "r", encoding="utf-8") as f:
458-
return yaml.safe_load(f)
462+
return dict[str, Any](yaml.safe_load(f))
459463
elif suffix == ".json":
460464
with open(config_path, "r", encoding="utf-8") as f:
461-
return json.load(f)
465+
return dict[str, Any](json.load(f))
462466
else:
463467
raise ValueError(f"Unsupported config file format: {suffix}. Supported formats: .yaml, .yml, .json")
464468

src/strands/multiagent/graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,12 @@ def from_config(cls, config: Union[str, Path, Dict[str, Any]]) -> "GraphBuilder"
275275
if "nodes" in graph_config:
276276
nodes = loader._load_nodes(graph_config["nodes"])
277277
for node_id, node in nodes.items():
278-
builder.nodes[node_id] = node
278+
builder.nodes[node_id] = node # type: ignore
279279

280280
# Load edges with conditions
281281
if "edges" in graph_config:
282-
edges = loader._load_edges(graph_config["edges"], builder.nodes)
283-
builder.edges = edges
282+
edges = loader._load_edges(graph_config["edges"], builder.nodes) # type: ignore
283+
builder.edges = edges # type: ignore
284284

285285
# Load entry points
286286
if "entry_points" in graph_config:
@@ -311,9 +311,9 @@ def _load_config_file(config_path: Union[str, Path]) -> Dict[str, Any]:
311311

312312
with open(path, "r") as f:
313313
if path.suffix.lower() in [".yaml", ".yml"]:
314-
return yaml.safe_load(f)
314+
return dict[str, Any](yaml.safe_load(f))
315315
elif path.suffix.lower() == ".json":
316-
return json.load(f)
316+
return dict[str, Any](json.load(f))
317317
else:
318318
raise ValueError(f"Unsupported config file format: {path.suffix}")
319319

src/strands/multiagent/swarm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def __init__(
302302
self.repetitive_handoff_min_unique_agents = repetitive_handoff_min_unique_agents
303303

304304
# Setup swarm with loaded agents
305-
self._setup_swarm(agents)
305+
self._setup_swarm(agents) # type: ignore
306306
self._inject_swarm_tools()
307307

308308
# Handle traditional initialization with nodes
@@ -314,7 +314,7 @@ def __init__(
314314
self.repetitive_handoff_detection_window = repetitive_handoff_detection_window
315315
self.repetitive_handoff_min_unique_agents = repetitive_handoff_min_unique_agents
316316

317-
self._setup_swarm(nodes)
317+
self._setup_swarm(nodes) # type: ignore
318318
self._inject_swarm_tools()
319319

320320
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult:
@@ -750,10 +750,10 @@ def _load_config(self, config: Union[str, Path, Dict[str, Any]]) -> Dict[str, An
750750

751751
if suffix in [".yaml", ".yml"]:
752752
with open(config_path, "r", encoding="utf-8") as f:
753-
return yaml.safe_load(f)
753+
return dict[str, Any](yaml.safe_load(f))
754754
elif suffix == ".json":
755755
with open(config_path, "r", encoding="utf-8") as f:
756-
return json.load(f)
756+
return dict[str, Any](json.load(f))
757757
else:
758758
raise ValueError(f"Unsupported config file format: {suffix}. Supported formats: .yaml, .yml, .json")
759759

0 commit comments

Comments
 (0)