diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index de5caad554..6b7806ab04 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -351,18 +351,37 @@ async def _execute_handoff( prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {}) agent_max_step = int(prov_settings.get("max_agent_step", 30)) stream = prov_settings.get("streaming_response", False) - llm_resp = await ctx.tool_loop_agent( - event=event, - chat_provider_id=prov_id, - prompt=input_, - image_urls=image_urls, - system_prompt=tool.agent.instructions, - tools=toolset, - contexts=contexts, - max_steps=agent_max_step, - tool_call_timeout=run_context.tool_call_timeout, - stream=stream, - ) + orchestrator = getattr(ctx, "subagent_orchestrator", None) + subagent_runner = getattr(orchestrator, "runner", None) + if subagent_runner is not None: + llm_resp = await subagent_runner.run( + tool=tool, + run_context=run_context, + event=event, + ctx=ctx, + provider_id=prov_id, + input_=input_, + image_urls=image_urls, + system_prompt=tool.agent.instructions, + tools=toolset, + begin_contexts=contexts, + max_steps=agent_max_step, + tool_call_timeout=run_context.tool_call_timeout, + stream=stream, + ) + else: + llm_resp = await ctx.tool_loop_agent( + event=event, + chat_provider_id=prov_id, + prompt=input_, + image_urls=image_urls, + system_prompt=tool.agent.instructions, + tools=toolset, + contexts=contexts, + max_steps=agent_max_step, + tool_call_timeout=run_context.tool_call_timeout, + stream=stream, + ) yield mcp.types.CallToolResult( content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)] ) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 593bad9365..9f3f1bc5fa 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -3,12 +3,14 @@ import logging from asyncio import Queue from collections.abc import Awaitable, Callable +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Protocol from deprecated import deprecated from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.message import Message +from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import ToolSet from astrbot.core.astrbot_config_mgr import AstrBotConfigManager @@ -53,6 +55,12 @@ RegisteredWebApi = tuple[str, WebApiHandler, list[str], str] +@dataclass +class _ToolLoopAgentRunResult: + llm_response: LLMResponse + run_context: ContextWrapper[Any] + + class PlatformManagerProtocol(Protocol): platform_insts: list[Platform] @@ -192,6 +200,36 @@ async def tool_loop_agent( ChatProviderNotFoundError: If the specified chat provider ID is not found Exception: For other errors during LLM generation """ + result = await self._run_tool_loop_agent_internal( + event=event, + chat_provider_id=chat_provider_id, + prompt=prompt, + image_urls=image_urls, + audio_urls=audio_urls, + tools=tools, + system_prompt=system_prompt, + contexts=contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + **kwargs, + ) + return result.llm_response + + async def _run_tool_loop_agent_internal( + self, + *, + event: AstrMessageEvent, + chat_provider_id: str, + prompt: str | None = None, + image_urls: list[str] | None = None, + audio_urls: list[str] | None = None, + tools: ToolSet | None = None, + system_prompt: str | None = None, + contexts: list[Message] | None = None, + max_steps: int = 30, + tool_call_timeout: int = 120, + **kwargs: Any, + ) -> _ToolLoopAgentRunResult: # Import here to avoid circular imports from astrbot.core.astr_agent_context import ( AgentContextWrapper, @@ -261,7 +299,10 @@ async def tool_loop_agent( llm_resp = agent_runner.get_final_llm_resp() if not llm_resp: raise Exception("Agent did not produce a final LLM response") - return llm_resp + return _ToolLoopAgentRunResult( + llm_response=llm_resp, + run_context=agent_runner.run_context, + ) async def get_current_chat_provider_id(self, umo: str) -> str: """获取当前使用的聊天模型 Provider ID。 diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index c6c595dfc9..e25e89e2f4 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -7,6 +7,12 @@ from astrbot.core.agent.agent import Agent from astrbot.core.agent.handoff import HandoffTool from astrbot.core.provider.func_tool_manager import FunctionToolManager +from astrbot.core.subagent_runner import ( + SubAgentRunner, + SubAgentSessionManager, + build_subagent_config_fingerprint, + normalize_context_persistence, +) if TYPE_CHECKING: from astrbot.core.persona_mgr import PersonaManager @@ -25,6 +31,8 @@ def __init__( self._tool_mgr = tool_mgr self._persona_mgr = persona_mgr self.handoffs: list[HandoffTool] = [] + self.session_manager = SubAgentSessionManager() + self.runner = SubAgentRunner(self.session_manager) async def reload_from_config(self, cfg: dict[str, Any]) -> None: from astrbot.core.astr_agent_context import AstrAgentContext @@ -35,6 +43,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: return handoffs: list[HandoffTool] = [] + persistent_agent_names: set[str] = set() for item in agents: if not isinstance(item, dict): continue @@ -60,6 +69,9 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: provider_id = item.get("provider_id") if provider_id is not None: provider_id = str(provider_id).strip() or None + context_persistence = normalize_context_persistence( + item.get("context_persistence") + ) tools = item.get("tools", []) begin_dialogs = None @@ -95,6 +107,19 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: # Optional per-subagent chat provider override. handoff.provider_id = provider_id + handoff.context_persistence = context_persistence + handoff.config_fingerprint = build_subagent_config_fingerprint( + { + "name": name, + "persona_id": persona_id, + "instructions": instructions, + "tools": tools, + "provider_id": provider_id, + "context_persistence": context_persistence, + } + ) + if context_persistence["enable"]: + persistent_agent_names.add(name) handoffs.append(handoff) @@ -102,3 +127,4 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: logger.info(f"Registered subagent handoff tool: {handoff.name}") self.handoffs = handoffs + self.session_manager.clear_except_agents(persistent_agent_names) diff --git a/astrbot/core/subagent_runner.py b/astrbot/core/subagent_runner.py new file mode 100644 index 0000000000..30f184e28c --- /dev/null +++ b/astrbot/core/subagent_runner.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from astrbot import logger +from astrbot.core.agent.message import Message +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.provider.entities import LLMResponse + +if TYPE_CHECKING: + from astrbot.core.astr_agent_context import AstrAgentContext + +DEFAULT_CONTEXT_PERSISTENCE: dict[str, Any] = { + "enable": False, + "max_turns": 10, + "ttl_seconds": 3600, +} + + +def _positive_int(value: Any, default: int) -> int: + try: + parsed = int(float(value)) + except (TypeError, ValueError, OverflowError): + return default + return parsed if parsed > 0 else default + + +def _ttl_seconds(value: Any, default: int) -> int: + try: + parsed = int(float(value)) + except (TypeError, ValueError, OverflowError): + return default + if parsed == -1: + return parsed + return parsed if parsed > 0 else default + + +def normalize_context_persistence(raw: Any) -> dict[str, Any]: + defaults = DEFAULT_CONTEXT_PERSISTENCE + data = raw if isinstance(raw, dict) else {} + return { + "enable": bool(data.get("enable", defaults["enable"])), + "max_turns": _positive_int(data.get("max_turns"), defaults["max_turns"]), + "ttl_seconds": _ttl_seconds(data.get("ttl_seconds"), defaults["ttl_seconds"]), + } + + +def build_subagent_config_fingerprint(payload: dict[str, Any]) -> str: + stable_payload = json.dumps( + payload, sort_keys=True, ensure_ascii=False, default=str + ) + return hashlib.sha256(stable_payload.encode("utf-8")).hexdigest() + + +@dataclass +class _SubAgentContextRecord: + messages: list[Message] + last_used_at: float + config_fingerprint: str + + +class SubAgentSessionManager: + def __init__(self) -> None: + self._records: dict[tuple[str, str, str], _SubAgentContextRecord] = {} + self._locks: dict[tuple[str, str, str], asyncio.Lock] = {} + + def build_key( + self, + run_context: ContextWrapper[AstrAgentContext], + subagent_name: str, + ) -> tuple[str, str, str]: + event = run_context.context.event + unified_msg_origin = getattr(event, "unified_msg_origin", "") or "" + session_id = getattr(event, "session_id", "") or unified_msg_origin + return (unified_msg_origin, session_id, subagent_name) + + def get_lock(self, key: tuple[str, str, str]) -> asyncio.Lock: + lock = self._locks.get(key) + if lock is None: + lock = asyncio.Lock() + self._locks[key] = lock + return lock + + def clear(self, key: tuple[str, str, str]) -> None: + self._records.pop(key, None) + lock = self._locks.get(key) + if lock is not None and not lock.locked(): + self._locks.pop(key, None) + + def clear_except_agents(self, agent_names: set[str]) -> None: + stale_keys = [key for key in self._records if key[2] not in agent_names] + for key in stale_keys: + self.clear(key) + + def get_messages( + self, + key: tuple[str, str, str], + *, + ttl_seconds: int, + config_fingerprint: str, + now: float | None = None, + ) -> list[Message] | None: + record = self._records.get(key) + if record is None: + return None + + now = time.monotonic() if now is None else now + if ttl_seconds != -1 and now - record.last_used_at > ttl_seconds: + self.clear(key) + return None + if record.config_fingerprint != config_fingerprint: + self.clear(key) + return None + return [self._clone_message(message) for message in record.messages] + + def set_messages( + self, + key: tuple[str, str, str], + messages: list[Message], + *, + config_fingerprint: str, + context_persistence: dict[str, Any], + now: float | None = None, + ) -> None: + trimmed = self._trim_messages( + messages, + max_turns=context_persistence["max_turns"], + ) + self._records[key] = _SubAgentContextRecord( + messages=trimmed, + last_used_at=time.monotonic() if now is None else now, + config_fingerprint=config_fingerprint, + ) + + def _trim_messages( + self, + messages: list[Message], + *, + max_turns: int, + ) -> list[Message]: + groups = self._group_messages(messages) + groups = self._trim_groups_by_turns(groups, max_turns) + return [message for group in groups for message in group] + + def _group_messages(self, messages: list[Message]) -> list[list[Message]]: + groups: list[list[Message]] = [] + index = 0 + while index < len(messages): + message = messages[index] + if message.role in {"system", "_checkpoint"}: + index += 1 + continue + if message.role == "tool": + index += 1 + continue + + cloned = self._clone_message(message) + group = [cloned] + if message.role == "assistant" and message.tool_calls: + expected_ids = { + tool_call_id + for tool_call in message.tool_calls + if (tool_call_id := self._tool_call_id(tool_call)) is not None + } + next_index = index + 1 + while next_index < len(messages): + next_message = messages[next_index] + if ( + next_message.role == "tool" + and next_message.tool_call_id in expected_ids + ): + group.append(self._clone_message(next_message)) + next_index += 1 + continue + break + index = next_index + else: + index += 1 + groups.append(group) + return groups + + def _trim_groups_by_turns( + self, groups: list[list[Message]], max_turns: int + ) -> list[list[Message]]: + user_group_indexes = [ + index + for index, group in enumerate(groups) + if group and group[0].role == "user" + ] + if len(user_group_indexes) <= max_turns: + return groups + first_kept_index = user_group_indexes[-max_turns] + return groups[first_kept_index:] + + @staticmethod + def _clone_message(message: Message) -> Message: + return message.model_copy(deep=True) + + @staticmethod + def _tool_call_id(tool_call: Any) -> str | None: + raw_id = ( + tool_call.get("id") + if isinstance(tool_call, dict) + else getattr(tool_call, "id", None) + ) + if raw_id is None: + return None + return str(raw_id) + + +class SubAgentRunner: + def __init__(self, session_manager: SubAgentSessionManager) -> None: + self._session_manager = session_manager + + async def run( + self, + *, + tool: Any, + run_context: ContextWrapper[AstrAgentContext], + event: Any, + ctx: Any, + provider_id: str, + input_: str | None, + image_urls: list[str], + system_prompt: str, + tools: Any, + begin_contexts: list[Message] | None, + max_steps: int, + tool_call_timeout: int, + stream: bool, + ) -> LLMResponse: + context_persistence = normalize_context_persistence( + getattr(tool, "context_persistence", None) + ) + key = self._session_manager.build_key(run_context, tool.agent.name) + if not context_persistence["enable"]: + self._session_manager.clear(key) + return await self._run_stateless( + event=event, + ctx=ctx, + provider_id=provider_id, + input_=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + begin_contexts=begin_contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) + + internal_run = getattr(ctx, "_run_tool_loop_agent_internal", None) + if internal_run is None: + logger.debug( + "Context._run_tool_loop_agent_internal is unavailable; falling " + "back to stateless SubAgent handoff." + ) + return await self._run_stateless( + event=event, + ctx=ctx, + provider_id=provider_id, + input_=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + begin_contexts=begin_contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) + + config_fingerprint = getattr(tool, "config_fingerprint", "") + lock = self._session_manager.get_lock(key) + async with lock: + persisted_contexts = self._session_manager.get_messages( + key, + ttl_seconds=context_persistence["ttl_seconds"], + config_fingerprint=config_fingerprint, + ) + contexts = ( + persisted_contexts if persisted_contexts is not None else begin_contexts + ) + + result = await internal_run( + event=event, + chat_provider_id=provider_id, + prompt=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + contexts=contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) + self._session_manager.set_messages( + key, + result.run_context.messages, + config_fingerprint=config_fingerprint, + context_persistence=context_persistence, + ) + return result.llm_response + + async def _run_stateless( + self, + *, + event: Any, + ctx: Any, + provider_id: str, + input_: str | None, + image_urls: list[str], + system_prompt: str, + tools: Any, + begin_contexts: list[Message] | None, + max_steps: int, + tool_call_timeout: int, + stream: bool, + ) -> LLMResponse: + return await ctx.tool_loop_agent( + event=event, + chat_provider_id=provider_id, + prompt=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + contexts=begin_contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) diff --git a/astrbot/dashboard/routes/subagent.py b/astrbot/dashboard/routes/subagent.py index e3d77f73ad..e4275ffa0f 100644 --- a/astrbot/dashboard/routes/subagent.py +++ b/astrbot/dashboard/routes/subagent.py @@ -5,6 +5,7 @@ from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.subagent_runner import normalize_context_persistence from .route import Response, Route, RouteContext @@ -59,6 +60,9 @@ async def get_config(self): if isinstance(a, dict): a.setdefault("provider_id", None) a.setdefault("persona_id", None) + a["context_persistence"] = normalize_context_persistence( + a.get("context_persistence") + ) return jsonify(Response().ok(data=data).__dict__) except Exception as e: logger.error(traceback.format_exc()) @@ -70,6 +74,13 @@ async def update_config(self): if not isinstance(data, dict): return jsonify(Response().error("配置必须为 JSON 对象").__dict__) + if isinstance(data.get("agents"), list): + for a in data["agents"]: + if isinstance(a, dict): + a["context_persistence"] = normalize_context_persistence( + a.get("context_persistence") + ) + cfg = self.core_lifecycle.astrbot_config cfg["subagent_orchestrator"] = data diff --git a/dashboard/src/i18n/locales/en-US/features/subagent.json b/dashboard/src/i18n/locales/en-US/features/subagent.json index e9ea127f51..43b2ee40c4 100644 --- a/dashboard/src/i18n/locales/en-US/features/subagent.json +++ b/dashboard/src/i18n/locales/en-US/features/subagent.json @@ -63,6 +63,12 @@ "descriptionLabel": "Description for the main LLM (used to decide handoff)", "descriptionHint": "Shown to the main LLM as the transfer_to_* tool description—keep it short and clear." }, + "contextPersistence": { + "title": "Context persistence", + "subtitle": "Keep this SubAgent's private memory across repeated handoffs in the same session.", + "maxTurns": "Max turns", + "ttlSeconds": "Idle TTL seconds (-1 disables expiry)" + }, "messages": { "loadConfigFailed": "Failed to load config", "loadPersonaFailed": "Failed to load persona list", diff --git a/dashboard/src/i18n/locales/ru-RU/features/subagent.json b/dashboard/src/i18n/locales/ru-RU/features/subagent.json index 4f6b298b4d..5ac6a7610d 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/subagent.json +++ b/dashboard/src/i18n/locales/ru-RU/features/subagent.json @@ -1,4 +1,4 @@ -{ +{ "header": { "eyebrow": "Orchestration" }, @@ -63,6 +63,12 @@ "descriptionLabel": "Описание для основного LLM (используется для принятия решения о handoff)", "descriptionHint": "Отображается как описание инструмента transfer_to_* — будьте кратки и ясны." }, + "contextPersistence": { + "title": "Сохранение контекста", + "subtitle": "Сохранять приватную память этого SubAgent при повторных handoff в той же сессии.", + "maxTurns": "Максимум ходов", + "ttlSeconds": "TTL простоя, секунд (-1 отключает истечение)" + }, "messages": { "loadConfigFailed": "Не удалось загрузить конфигурацию", "loadPersonaFailed": "Не удалось загрузить список персонажей", diff --git a/dashboard/src/i18n/locales/zh-CN/features/subagent.json b/dashboard/src/i18n/locales/zh-CN/features/subagent.json index cd49ae432d..8981b5b0fd 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/subagent.json +++ b/dashboard/src/i18n/locales/zh-CN/features/subagent.json @@ -64,6 +64,12 @@ "descriptionLabel": "对主 LLM 的描述(用于决定是否 handoff)", "descriptionHint": "这段会作为 transfer_to_* 工具的描述给主 LLM 看,建议简短明确。" }, + "contextPersistence": { + "title": "上下文持久化", + "subtitle": "在同一会话重复 handoff 时保留该子代理的私有记忆。", + "maxTurns": "最多轮数", + "ttlSeconds": "空闲过期秒数(-1 表示不过期)" + }, "messages": { "loadConfigFailed": "获取配置失败", "loadPersonaFailed": "获取 Persona 列表失败", diff --git a/dashboard/src/views/SubAgentPage.vue b/dashboard/src/views/SubAgentPage.vue index d3876ec4c8..07b2901682 100644 --- a/dashboard/src/views/SubAgentPage.vue +++ b/dashboard/src/views/SubAgentPage.vue @@ -184,6 +184,48 @@ auto-grow hide-details="auto" /> + +