Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,37 @@ async def _execute_handoff(
prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {})
agent_max_step = int(prov_settings.get("max_agent_step", 30))
stream = prov_settings.get("streaming_response", False)
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
prompt=input_,
image_urls=image_urls,
system_prompt=tool.agent.instructions,
tools=toolset,
contexts=contexts,
max_steps=agent_max_step,
tool_call_timeout=run_context.tool_call_timeout,
stream=stream,
)
orchestrator = getattr(ctx, "subagent_orchestrator", None)
subagent_runner = getattr(orchestrator, "runner", None)
if subagent_runner is not None:
llm_resp = await subagent_runner.run(
tool=tool,
run_context=run_context,
event=event,
ctx=ctx,
provider_id=prov_id,
input_=input_,
image_urls=image_urls,
system_prompt=tool.agent.instructions,
tools=toolset,
begin_contexts=contexts,
max_steps=agent_max_step,
tool_call_timeout=run_context.tool_call_timeout,
stream=stream,
)
else:
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
prompt=input_,
image_urls=image_urls,
system_prompt=tool.agent.instructions,
tools=toolset,
contexts=contexts,
max_steps=agent_max_step,
tool_call_timeout=run_context.tool_call_timeout,
stream=stream,
)
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
)
Expand Down
43 changes: 42 additions & 1 deletion astrbot/core/star/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol

from deprecated import deprecated

from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import Message
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
Expand Down Expand Up @@ -53,6 +55,12 @@
RegisteredWebApi = tuple[str, WebApiHandler, list[str], str]


@dataclass
class _ToolLoopAgentRunResult:
llm_response: LLMResponse
run_context: ContextWrapper[Any]


class PlatformManagerProtocol(Protocol):
platform_insts: list[Platform]

Expand Down Expand Up @@ -192,6 +200,36 @@ async def tool_loop_agent(
ChatProviderNotFoundError: If the specified chat provider ID is not found
Exception: For other errors during LLM generation
"""
result = await self._run_tool_loop_agent_internal(
event=event,
chat_provider_id=chat_provider_id,
prompt=prompt,
image_urls=image_urls,
audio_urls=audio_urls,
tools=tools,
system_prompt=system_prompt,
contexts=contexts,
max_steps=max_steps,
tool_call_timeout=tool_call_timeout,
**kwargs,
)
return result.llm_response

async def _run_tool_loop_agent_internal(
self,
*,
event: AstrMessageEvent,
chat_provider_id: str,
prompt: str | None = None,
image_urls: list[str] | None = None,
audio_urls: list[str] | None = None,
tools: ToolSet | None = None,
system_prompt: str | None = None,
contexts: list[Message] | None = None,
max_steps: int = 30,
tool_call_timeout: int = 120,
**kwargs: Any,
) -> _ToolLoopAgentRunResult:
# Import here to avoid circular imports
from astrbot.core.astr_agent_context import (
AgentContextWrapper,
Expand Down Expand Up @@ -261,7 +299,10 @@ async def tool_loop_agent(
llm_resp = agent_runner.get_final_llm_resp()
if not llm_resp:
raise Exception("Agent did not produce a final LLM response")
return llm_resp
return _ToolLoopAgentRunResult(
llm_response=llm_resp,
run_context=agent_runner.run_context,
)

async def get_current_chat_provider_id(self, umo: str) -> str:
"""获取当前使用的聊天模型 Provider ID。
Expand Down
26 changes: 26 additions & 0 deletions astrbot/core/subagent_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.subagent_runner import (
SubAgentRunner,
SubAgentSessionManager,
build_subagent_config_fingerprint,
normalize_context_persistence,
)

if TYPE_CHECKING:
from astrbot.core.persona_mgr import PersonaManager
Expand All @@ -25,6 +31,8 @@ def __init__(
self._tool_mgr = tool_mgr
self._persona_mgr = persona_mgr
self.handoffs: list[HandoffTool] = []
self.session_manager = SubAgentSessionManager()
self.runner = SubAgentRunner(self.session_manager)

async def reload_from_config(self, cfg: dict[str, Any]) -> None:
from astrbot.core.astr_agent_context import AstrAgentContext
Expand All @@ -35,6 +43,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None:
return

handoffs: list[HandoffTool] = []
persistent_agent_names: set[str] = set()
for item in agents:
if not isinstance(item, dict):
continue
Expand All @@ -60,6 +69,9 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None:
provider_id = item.get("provider_id")
if provider_id is not None:
provider_id = str(provider_id).strip() or None
context_persistence = normalize_context_persistence(
item.get("context_persistence")
)
tools = item.get("tools", [])
begin_dialogs = None

Expand Down Expand Up @@ -95,10 +107,24 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None:

# Optional per-subagent chat provider override.
handoff.provider_id = provider_id
handoff.context_persistence = context_persistence
handoff.config_fingerprint = build_subagent_config_fingerprint(
{
"name": name,
"persona_id": persona_id,
"instructions": instructions,
"tools": tools,
"provider_id": provider_id,
"context_persistence": context_persistence,
}
)
if context_persistence["enable"]:
persistent_agent_names.add(name)

handoffs.append(handoff)

for handoff in handoffs:
logger.info(f"Registered subagent handoff tool: {handoff.name}")

self.handoffs = handoffs
self.session_manager.clear_except_agents(persistent_agent_names)
Loading