diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index de5caad554..6b7806ab04 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -351,18 +351,37 @@ async def _execute_handoff( prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {}) agent_max_step = int(prov_settings.get("max_agent_step", 30)) stream = prov_settings.get("streaming_response", False) - llm_resp = await ctx.tool_loop_agent( - event=event, - chat_provider_id=prov_id, - prompt=input_, - image_urls=image_urls, - system_prompt=tool.agent.instructions, - tools=toolset, - contexts=contexts, - max_steps=agent_max_step, - tool_call_timeout=run_context.tool_call_timeout, - stream=stream, - ) + orchestrator = getattr(ctx, "subagent_orchestrator", None) + subagent_runner = getattr(orchestrator, "runner", None) + if subagent_runner is not None: + llm_resp = await subagent_runner.run( + tool=tool, + run_context=run_context, + event=event, + ctx=ctx, + provider_id=prov_id, + input_=input_, + image_urls=image_urls, + system_prompt=tool.agent.instructions, + tools=toolset, + begin_contexts=contexts, + max_steps=agent_max_step, + tool_call_timeout=run_context.tool_call_timeout, + stream=stream, + ) + else: + llm_resp = await ctx.tool_loop_agent( + event=event, + chat_provider_id=prov_id, + prompt=input_, + image_urls=image_urls, + system_prompt=tool.agent.instructions, + tools=toolset, + contexts=contexts, + max_steps=agent_max_step, + tool_call_timeout=run_context.tool_call_timeout, + stream=stream, + ) yield mcp.types.CallToolResult( content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)] ) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 593bad9365..9f3f1bc5fa 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -3,12 +3,14 @@ import logging from asyncio import Queue from collections.abc import Awaitable, Callable +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Protocol from deprecated import deprecated from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.message import Message +from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import ToolSet from astrbot.core.astrbot_config_mgr import AstrBotConfigManager @@ -53,6 +55,12 @@ RegisteredWebApi = tuple[str, WebApiHandler, list[str], str] +@dataclass +class _ToolLoopAgentRunResult: + llm_response: LLMResponse + run_context: ContextWrapper[Any] + + class PlatformManagerProtocol(Protocol): platform_insts: list[Platform] @@ -192,6 +200,36 @@ async def tool_loop_agent( ChatProviderNotFoundError: If the specified chat provider ID is not found Exception: For other errors during LLM generation """ + result = await self._run_tool_loop_agent_internal( + event=event, + chat_provider_id=chat_provider_id, + prompt=prompt, + image_urls=image_urls, + audio_urls=audio_urls, + tools=tools, + system_prompt=system_prompt, + contexts=contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + **kwargs, + ) + return result.llm_response + + async def _run_tool_loop_agent_internal( + self, + *, + event: AstrMessageEvent, + chat_provider_id: str, + prompt: str | None = None, + image_urls: list[str] | None = None, + audio_urls: list[str] | None = None, + tools: ToolSet | None = None, + system_prompt: str | None = None, + contexts: list[Message] | None = None, + max_steps: int = 30, + tool_call_timeout: int = 120, + **kwargs: Any, + ) -> _ToolLoopAgentRunResult: # Import here to avoid circular imports from astrbot.core.astr_agent_context import ( AgentContextWrapper, @@ -261,7 +299,10 @@ async def tool_loop_agent( llm_resp = agent_runner.get_final_llm_resp() if not llm_resp: raise Exception("Agent did not produce a final LLM response") - return llm_resp + return _ToolLoopAgentRunResult( + llm_response=llm_resp, + run_context=agent_runner.run_context, + ) async def get_current_chat_provider_id(self, umo: str) -> str: """获取当前使用的聊天模型 Provider ID。 diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index c6c595dfc9..e25e89e2f4 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -7,6 +7,12 @@ from astrbot.core.agent.agent import Agent from astrbot.core.agent.handoff import HandoffTool from astrbot.core.provider.func_tool_manager import FunctionToolManager +from astrbot.core.subagent_runner import ( + SubAgentRunner, + SubAgentSessionManager, + build_subagent_config_fingerprint, + normalize_context_persistence, +) if TYPE_CHECKING: from astrbot.core.persona_mgr import PersonaManager @@ -25,6 +31,8 @@ def __init__( self._tool_mgr = tool_mgr self._persona_mgr = persona_mgr self.handoffs: list[HandoffTool] = [] + self.session_manager = SubAgentSessionManager() + self.runner = SubAgentRunner(self.session_manager) async def reload_from_config(self, cfg: dict[str, Any]) -> None: from astrbot.core.astr_agent_context import AstrAgentContext @@ -35,6 +43,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: return handoffs: list[HandoffTool] = [] + persistent_agent_names: set[str] = set() for item in agents: if not isinstance(item, dict): continue @@ -60,6 +69,9 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: provider_id = item.get("provider_id") if provider_id is not None: provider_id = str(provider_id).strip() or None + context_persistence = normalize_context_persistence( + item.get("context_persistence") + ) tools = item.get("tools", []) begin_dialogs = None @@ -95,6 +107,19 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: # Optional per-subagent chat provider override. handoff.provider_id = provider_id + handoff.context_persistence = context_persistence + handoff.config_fingerprint = build_subagent_config_fingerprint( + { + "name": name, + "persona_id": persona_id, + "instructions": instructions, + "tools": tools, + "provider_id": provider_id, + "context_persistence": context_persistence, + } + ) + if context_persistence["enable"]: + persistent_agent_names.add(name) handoffs.append(handoff) @@ -102,3 +127,4 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: logger.info(f"Registered subagent handoff tool: {handoff.name}") self.handoffs = handoffs + self.session_manager.clear_except_agents(persistent_agent_names) diff --git a/astrbot/core/subagent_runner.py b/astrbot/core/subagent_runner.py new file mode 100644 index 0000000000..30f184e28c --- /dev/null +++ b/astrbot/core/subagent_runner.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from astrbot import logger +from astrbot.core.agent.message import Message +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.provider.entities import LLMResponse + +if TYPE_CHECKING: + from astrbot.core.astr_agent_context import AstrAgentContext + +DEFAULT_CONTEXT_PERSISTENCE: dict[str, Any] = { + "enable": False, + "max_turns": 10, + "ttl_seconds": 3600, +} + + +def _positive_int(value: Any, default: int) -> int: + try: + parsed = int(float(value)) + except (TypeError, ValueError, OverflowError): + return default + return parsed if parsed > 0 else default + + +def _ttl_seconds(value: Any, default: int) -> int: + try: + parsed = int(float(value)) + except (TypeError, ValueError, OverflowError): + return default + if parsed == -1: + return parsed + return parsed if parsed > 0 else default + + +def normalize_context_persistence(raw: Any) -> dict[str, Any]: + defaults = DEFAULT_CONTEXT_PERSISTENCE + data = raw if isinstance(raw, dict) else {} + return { + "enable": bool(data.get("enable", defaults["enable"])), + "max_turns": _positive_int(data.get("max_turns"), defaults["max_turns"]), + "ttl_seconds": _ttl_seconds(data.get("ttl_seconds"), defaults["ttl_seconds"]), + } + + +def build_subagent_config_fingerprint(payload: dict[str, Any]) -> str: + stable_payload = json.dumps( + payload, sort_keys=True, ensure_ascii=False, default=str + ) + return hashlib.sha256(stable_payload.encode("utf-8")).hexdigest() + + +@dataclass +class _SubAgentContextRecord: + messages: list[Message] + last_used_at: float + config_fingerprint: str + + +class SubAgentSessionManager: + def __init__(self) -> None: + self._records: dict[tuple[str, str, str], _SubAgentContextRecord] = {} + self._locks: dict[tuple[str, str, str], asyncio.Lock] = {} + + def build_key( + self, + run_context: ContextWrapper[AstrAgentContext], + subagent_name: str, + ) -> tuple[str, str, str]: + event = run_context.context.event + unified_msg_origin = getattr(event, "unified_msg_origin", "") or "" + session_id = getattr(event, "session_id", "") or unified_msg_origin + return (unified_msg_origin, session_id, subagent_name) + + def get_lock(self, key: tuple[str, str, str]) -> asyncio.Lock: + lock = self._locks.get(key) + if lock is None: + lock = asyncio.Lock() + self._locks[key] = lock + return lock + + def clear(self, key: tuple[str, str, str]) -> None: + self._records.pop(key, None) + lock = self._locks.get(key) + if lock is not None and not lock.locked(): + self._locks.pop(key, None) + + def clear_except_agents(self, agent_names: set[str]) -> None: + stale_keys = [key for key in self._records if key[2] not in agent_names] + for key in stale_keys: + self.clear(key) + + def get_messages( + self, + key: tuple[str, str, str], + *, + ttl_seconds: int, + config_fingerprint: str, + now: float | None = None, + ) -> list[Message] | None: + record = self._records.get(key) + if record is None: + return None + + now = time.monotonic() if now is None else now + if ttl_seconds != -1 and now - record.last_used_at > ttl_seconds: + self.clear(key) + return None + if record.config_fingerprint != config_fingerprint: + self.clear(key) + return None + return [self._clone_message(message) for message in record.messages] + + def set_messages( + self, + key: tuple[str, str, str], + messages: list[Message], + *, + config_fingerprint: str, + context_persistence: dict[str, Any], + now: float | None = None, + ) -> None: + trimmed = self._trim_messages( + messages, + max_turns=context_persistence["max_turns"], + ) + self._records[key] = _SubAgentContextRecord( + messages=trimmed, + last_used_at=time.monotonic() if now is None else now, + config_fingerprint=config_fingerprint, + ) + + def _trim_messages( + self, + messages: list[Message], + *, + max_turns: int, + ) -> list[Message]: + groups = self._group_messages(messages) + groups = self._trim_groups_by_turns(groups, max_turns) + return [message for group in groups for message in group] + + def _group_messages(self, messages: list[Message]) -> list[list[Message]]: + groups: list[list[Message]] = [] + index = 0 + while index < len(messages): + message = messages[index] + if message.role in {"system", "_checkpoint"}: + index += 1 + continue + if message.role == "tool": + index += 1 + continue + + cloned = self._clone_message(message) + group = [cloned] + if message.role == "assistant" and message.tool_calls: + expected_ids = { + tool_call_id + for tool_call in message.tool_calls + if (tool_call_id := self._tool_call_id(tool_call)) is not None + } + next_index = index + 1 + while next_index < len(messages): + next_message = messages[next_index] + if ( + next_message.role == "tool" + and next_message.tool_call_id in expected_ids + ): + group.append(self._clone_message(next_message)) + next_index += 1 + continue + break + index = next_index + else: + index += 1 + groups.append(group) + return groups + + def _trim_groups_by_turns( + self, groups: list[list[Message]], max_turns: int + ) -> list[list[Message]]: + user_group_indexes = [ + index + for index, group in enumerate(groups) + if group and group[0].role == "user" + ] + if len(user_group_indexes) <= max_turns: + return groups + first_kept_index = user_group_indexes[-max_turns] + return groups[first_kept_index:] + + @staticmethod + def _clone_message(message: Message) -> Message: + return message.model_copy(deep=True) + + @staticmethod + def _tool_call_id(tool_call: Any) -> str | None: + raw_id = ( + tool_call.get("id") + if isinstance(tool_call, dict) + else getattr(tool_call, "id", None) + ) + if raw_id is None: + return None + return str(raw_id) + + +class SubAgentRunner: + def __init__(self, session_manager: SubAgentSessionManager) -> None: + self._session_manager = session_manager + + async def run( + self, + *, + tool: Any, + run_context: ContextWrapper[AstrAgentContext], + event: Any, + ctx: Any, + provider_id: str, + input_: str | None, + image_urls: list[str], + system_prompt: str, + tools: Any, + begin_contexts: list[Message] | None, + max_steps: int, + tool_call_timeout: int, + stream: bool, + ) -> LLMResponse: + context_persistence = normalize_context_persistence( + getattr(tool, "context_persistence", None) + ) + key = self._session_manager.build_key(run_context, tool.agent.name) + if not context_persistence["enable"]: + self._session_manager.clear(key) + return await self._run_stateless( + event=event, + ctx=ctx, + provider_id=provider_id, + input_=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + begin_contexts=begin_contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) + + internal_run = getattr(ctx, "_run_tool_loop_agent_internal", None) + if internal_run is None: + logger.debug( + "Context._run_tool_loop_agent_internal is unavailable; falling " + "back to stateless SubAgent handoff." + ) + return await self._run_stateless( + event=event, + ctx=ctx, + provider_id=provider_id, + input_=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + begin_contexts=begin_contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) + + config_fingerprint = getattr(tool, "config_fingerprint", "") + lock = self._session_manager.get_lock(key) + async with lock: + persisted_contexts = self._session_manager.get_messages( + key, + ttl_seconds=context_persistence["ttl_seconds"], + config_fingerprint=config_fingerprint, + ) + contexts = ( + persisted_contexts if persisted_contexts is not None else begin_contexts + ) + + result = await internal_run( + event=event, + chat_provider_id=provider_id, + prompt=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + contexts=contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) + self._session_manager.set_messages( + key, + result.run_context.messages, + config_fingerprint=config_fingerprint, + context_persistence=context_persistence, + ) + return result.llm_response + + async def _run_stateless( + self, + *, + event: Any, + ctx: Any, + provider_id: str, + input_: str | None, + image_urls: list[str], + system_prompt: str, + tools: Any, + begin_contexts: list[Message] | None, + max_steps: int, + tool_call_timeout: int, + stream: bool, + ) -> LLMResponse: + return await ctx.tool_loop_agent( + event=event, + chat_provider_id=provider_id, + prompt=input_, + image_urls=image_urls, + system_prompt=system_prompt, + tools=tools, + contexts=begin_contexts, + max_steps=max_steps, + tool_call_timeout=tool_call_timeout, + stream=stream, + ) diff --git a/astrbot/dashboard/routes/subagent.py b/astrbot/dashboard/routes/subagent.py index e3d77f73ad..e4275ffa0f 100644 --- a/astrbot/dashboard/routes/subagent.py +++ b/astrbot/dashboard/routes/subagent.py @@ -5,6 +5,7 @@ from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.subagent_runner import normalize_context_persistence from .route import Response, Route, RouteContext @@ -59,6 +60,9 @@ async def get_config(self): if isinstance(a, dict): a.setdefault("provider_id", None) a.setdefault("persona_id", None) + a["context_persistence"] = normalize_context_persistence( + a.get("context_persistence") + ) return jsonify(Response().ok(data=data).__dict__) except Exception as e: logger.error(traceback.format_exc()) @@ -70,6 +74,13 @@ async def update_config(self): if not isinstance(data, dict): return jsonify(Response().error("配置必须为 JSON 对象").__dict__) + if isinstance(data.get("agents"), list): + for a in data["agents"]: + if isinstance(a, dict): + a["context_persistence"] = normalize_context_persistence( + a.get("context_persistence") + ) + cfg = self.core_lifecycle.astrbot_config cfg["subagent_orchestrator"] = data diff --git a/dashboard/src/i18n/locales/en-US/features/subagent.json b/dashboard/src/i18n/locales/en-US/features/subagent.json index e9ea127f51..43b2ee40c4 100644 --- a/dashboard/src/i18n/locales/en-US/features/subagent.json +++ b/dashboard/src/i18n/locales/en-US/features/subagent.json @@ -63,6 +63,12 @@ "descriptionLabel": "Description for the main LLM (used to decide handoff)", "descriptionHint": "Shown to the main LLM as the transfer_to_* tool description—keep it short and clear." }, + "contextPersistence": { + "title": "Context persistence", + "subtitle": "Keep this SubAgent's private memory across repeated handoffs in the same session.", + "maxTurns": "Max turns", + "ttlSeconds": "Idle TTL seconds (-1 disables expiry)" + }, "messages": { "loadConfigFailed": "Failed to load config", "loadPersonaFailed": "Failed to load persona list", diff --git a/dashboard/src/i18n/locales/ru-RU/features/subagent.json b/dashboard/src/i18n/locales/ru-RU/features/subagent.json index 4f6b298b4d..5ac6a7610d 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/subagent.json +++ b/dashboard/src/i18n/locales/ru-RU/features/subagent.json @@ -1,4 +1,4 @@ -{ +{ "header": { "eyebrow": "Orchestration" }, @@ -63,6 +63,12 @@ "descriptionLabel": "Описание для основного LLM (используется для принятия решения о handoff)", "descriptionHint": "Отображается как описание инструмента transfer_to_* — будьте кратки и ясны." }, + "contextPersistence": { + "title": "Сохранение контекста", + "subtitle": "Сохранять приватную память этого SubAgent при повторных handoff в той же сессии.", + "maxTurns": "Максимум ходов", + "ttlSeconds": "TTL простоя, секунд (-1 отключает истечение)" + }, "messages": { "loadConfigFailed": "Не удалось загрузить конфигурацию", "loadPersonaFailed": "Не удалось загрузить список персонажей", diff --git a/dashboard/src/i18n/locales/zh-CN/features/subagent.json b/dashboard/src/i18n/locales/zh-CN/features/subagent.json index cd49ae432d..8981b5b0fd 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/subagent.json +++ b/dashboard/src/i18n/locales/zh-CN/features/subagent.json @@ -64,6 +64,12 @@ "descriptionLabel": "对主 LLM 的描述(用于决定是否 handoff)", "descriptionHint": "这段会作为 transfer_to_* 工具的描述给主 LLM 看,建议简短明确。" }, + "contextPersistence": { + "title": "上下文持久化", + "subtitle": "在同一会话重复 handoff 时保留该子代理的私有记忆。", + "maxTurns": "最多轮数", + "ttlSeconds": "空闲过期秒数(-1 表示不过期)" + }, "messages": { "loadConfigFailed": "获取配置失败", "loadPersonaFailed": "获取 Persona 列表失败", diff --git a/dashboard/src/views/SubAgentPage.vue b/dashboard/src/views/SubAgentPage.vue index d3876ec4c8..07b2901682 100644 --- a/dashboard/src/views/SubAgentPage.vue +++ b/dashboard/src/views/SubAgentPage.vue @@ -184,6 +184,48 @@ auto-grow hide-details="auto" /> + +
+
+
+
{{ tm('contextPersistence.title') }}
+
{{ tm('contextPersistence.subtitle') }}
+
+ +
+ + +
+ + +
+
+
@@ -220,6 +262,12 @@ import ProviderSelector from '@/components/shared/ProviderSelector.vue' import { useModuleI18n } from '@/i18n/composables' import { askForConfirmation, useConfirmDialog } from '@/utils/confirmDialog' +type ContextPersistenceConfig = { + enable: boolean + max_turns: number + ttl_seconds: number +} + type SubAgentItem = { __key: string name: string @@ -227,6 +275,7 @@ type SubAgentItem = { public_description: string enabled: boolean provider_id?: string + context_persistence: ContextPersistenceConfig } type SubAgentConfig = { @@ -268,6 +317,35 @@ const mainStateDescription = computed(() => const hasUnsavedChanges = computed(() => hasLoaded.value && serializeConfig(cfg.value) !== initialSnapshot.value) +function normalizePositiveNumber(value: unknown, fallback: number): number { + const parsed = Number.parseInt(String(value), 10) + return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback +} + +function normalizeTtlSeconds(value: unknown, fallback: number): number { + const parsed = Number.parseInt(String(value), 10) + if (!Number.isFinite(parsed)) return fallback + if (parsed === -1) return parsed + return parsed > 0 ? parsed : fallback +} + +function defaultContextPersistence(): ContextPersistenceConfig { + return { + enable: false, + max_turns: 10, + ttl_seconds: 3600 + } +} + +function normalizeContextPersistence(raw: any): ContextPersistenceConfig { + const defaults = defaultContextPersistence() + return { + enable: !!raw?.enable, + max_turns: normalizePositiveNumber(raw?.max_turns, defaults.max_turns), + ttl_seconds: normalizeTtlSeconds(raw?.ttl_seconds, defaults.ttl_seconds) + } +} + function normalizeConfig(raw: any): SubAgentConfig { const main_enable = !!raw?.main_enable const remove_main_duplicate_tools = !!raw?.remove_main_duplicate_tools @@ -279,7 +357,8 @@ function normalizeConfig(raw: any): SubAgentConfig { persona_id: (a?.persona_id ?? '').toString(), public_description: (a?.public_description ?? '').toString(), enabled: a?.enabled !== false, - provider_id: (a?.provider_id ?? undefined) as string | undefined + provider_id: (a?.provider_id ?? undefined) as string | undefined, + context_persistence: normalizeContextPersistence(a?.context_persistence) })) return { main_enable, remove_main_duplicate_tools, agents } @@ -294,7 +373,8 @@ function serializeConfig(config: SubAgentConfig): string { persona_id: agent.persona_id, public_description: agent.public_description, enabled: agent.enabled, - provider_id: agent.provider_id ?? null + provider_id: agent.provider_id ?? null, + context_persistence: normalizeContextPersistence(agent.context_persistence) })) }) } @@ -326,7 +406,8 @@ function addAgent() { persona_id: '', public_description: '', enabled: true, - provider_id: undefined + provider_id: undefined, + context_persistence: defaultContextPersistence() }) expandedAgents.value[key] = false } @@ -386,7 +467,8 @@ async function save() { persona_id: agent.persona_id, public_description: agent.public_description, enabled: agent.enabled, - provider_id: agent.provider_id + provider_id: agent.provider_id, + context_persistence: normalizeContextPersistence(agent.context_persistence) })) } @@ -606,6 +688,20 @@ onBeforeRouteLeave(async () => { background: transparent; } +.context-persistence-panel { + border: 1px solid var(--dashboard-border); + border-radius: 12px; + padding: 16px; + background: rgba(var(--v-theme-primary), 0.02); +} + +.context-persistence-grid { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 14px; + margin-top: 16px; +} + .persona-preview-wrap { min-height: 320px; } @@ -622,5 +718,9 @@ onBeforeRouteLeave(async () => { flex-direction: column; align-items: flex-start; } + + .context-persistence-grid { + grid-template-columns: 1fr; + } } diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index ef6edc899d..c253f152e8 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1251,6 +1251,11 @@ async def test_subagent_config_accepts_default_persona( get_data = await get_response.get_json() assert get_data["status"] == "ok" assert get_data["data"]["agents"][0]["persona_id"] == "default" + assert get_data["data"]["agents"][0]["context_persistence"] == { + "enable": False, + "max_turns": 10, + "ttl_seconds": 3600, + } finally: await test_client.post( "/api/subagent/config", diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 5fab9fe0a2..0f353748fd 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -1,16 +1,25 @@ +import asyncio from types import SimpleNamespace import mcp import pytest +from astrbot.core.agent.message import Message from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.message.components import Image +from astrbot.core.subagent_orchestrator import SubAgentOrchestrator class _DummyEvent: - def __init__(self, message_components: list[object] | None = None) -> None: - self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session" + def __init__( + self, + message_components: list[object] | None = None, + unified_msg_origin: str = "webchat:FriendMessage:webchat!user!session", + session_id: str = "webchat!user!session", + ) -> None: + self.unified_msg_origin = unified_msg_origin + self.session_id = session_id self.message_obj = SimpleNamespace(message=message_components or []) def get_extra(self, _key: str): @@ -321,6 +330,174 @@ async def _fake_tool_loop_agent(**kwargs): assert captured["tool_call_timeout"] == 120 +def _build_persistent_handoff_tool(): + return SimpleNamespace( + name="transfer_to_subagent", + provider_id=None, + context_persistence={ + "enable": True, + "max_turns": 10, + "ttl_seconds": 3600, + }, + config_fingerprint="fingerprint", + agent=SimpleNamespace( + name="subagent", + tools=[], + instructions="subagent-instructions", + begin_dialogs=[{"role": "user", "content": "begin"}], + run_hooks=None, + ), + ) + + +@pytest.mark.asyncio +async def test_execute_handoff_persists_private_subagent_context(): + captured_contexts: list[list[Message] | None] = [] + + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" + + async def _fake_run_internal(**kwargs): + contexts = kwargs.get("contexts") + captured_contexts.append(contexts) + messages = list(contexts or []) + messages.extend( + [ + Message(role="user", content=kwargs["prompt"]), + Message(role="assistant", content=f"reply {len(captured_contexts)}"), + ] + ) + return SimpleNamespace( + llm_response=SimpleNamespace( + completion_text=f"reply {len(captured_contexts)}" + ), + run_context=SimpleNamespace(messages=messages), + ) + + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + _run_tool_loop_agent_internal=_fake_run_internal, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + context.subagent_orchestrator = SubAgentOrchestrator( + tool_mgr=SimpleNamespace(), persona_mgr=SimpleNamespace() + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = _build_persistent_handoff_tool() + + for prompt in ("remember alpha", "what did I say"): + results = [] + async for result in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input=prompt, + image_urls=[], + ): + results.append(result) + assert len(results) == 1 + + assert captured_contexts[0][0].content == "begin" + assert [message.content for message in captured_contexts[1]] == [ + "begin", + "remember alpha", + "reply 1", + ] + + +@pytest.mark.asyncio +async def test_execute_handoff_context_persistence_disabled_uses_plain_tool_loop(): + captured: dict = {} + + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" + + async def _fake_tool_loop_agent(**kwargs): + captured.update(kwargs) + return SimpleNamespace(completion_text="ok") + + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + tool_loop_agent=_fake_tool_loop_agent, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + context.subagent_orchestrator = SubAgentOrchestrator( + tool_mgr=SimpleNamespace(), persona_mgr=SimpleNamespace() + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = _build_persistent_handoff_tool() + tool.context_persistence = {"enable": False} + + results = [] + async for result in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input="hello", + image_urls=[], + ): + results.append(result) + + assert len(results) == 1 + assert captured["prompt"] == "hello" + assert captured["contexts"][0].content == "begin" + + +@pytest.mark.asyncio +async def test_execute_handoff_context_persistence_lock_serializes_same_key(): + active = 0 + max_active = 0 + + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" + + async def _fake_run_internal(**kwargs): + nonlocal active, max_active + active += 1 + max_active = max(max_active, active) + await asyncio.sleep(0.01) + messages = list(kwargs.get("contexts") or []) + messages.extend( + [ + Message(role="user", content=kwargs["prompt"]), + Message(role="assistant", content="ok"), + ] + ) + active -= 1 + return SimpleNamespace( + llm_response=SimpleNamespace(completion_text="ok"), + run_context=SimpleNamespace(messages=messages), + ) + + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + _run_tool_loop_agent_internal=_fake_run_internal, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + context.subagent_orchestrator = SubAgentOrchestrator( + tool_mgr=SimpleNamespace(), persona_mgr=SimpleNamespace() + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = _build_persistent_handoff_tool() + + async def _run_once(prompt: str): + async for _ in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input=prompt, + image_urls=[], + ): + pass + + await asyncio.gather(_run_once("one"), _run_once("two")) + + assert max_active == 1 + + @pytest.mark.asyncio async def test_collect_handoff_image_urls_filters_extensionless_file_outside_temp_root( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_subagent_orchestrator.py b/tests/unit/test_subagent_orchestrator.py index 9befac8872..d38188562e 100644 --- a/tests/unit/test_subagent_orchestrator.py +++ b/tests/unit/test_subagent_orchestrator.py @@ -1,9 +1,16 @@ from copy import deepcopy +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from astrbot.core.agent.message import Message +from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.subagent_orchestrator import SubAgentOrchestrator +from astrbot.core.subagent_runner import ( + SubAgentSessionManager, + normalize_context_persistence, +) def _build_cfg(agent_overrides: dict) -> dict: @@ -108,3 +115,245 @@ async def test_reload_from_config_tool_normalization(raw_tools, expected_tools): handoff = orchestrator.handoffs[0] assert handoff.agent.tools == expected_tools + + +def test_normalize_context_persistence_defaults_ttl_to_one_hour(): + assert normalize_context_persistence(None) == { + "enable": False, + "max_turns": 10, + "ttl_seconds": 3600, + } + + +def test_normalize_context_persistence_allows_ttl_without_expiry(): + assert normalize_context_persistence({"ttl_seconds": -1})["ttl_seconds"] == -1 + + +def test_normalize_context_persistence_accepts_float_strings(): + assert normalize_context_persistence( + {"max_turns": "3.0", "ttl_seconds": "3600.0"} + ) == { + "enable": False, + "max_turns": 3, + "ttl_seconds": 3600, + } + + +@pytest.mark.asyncio +async def test_reload_from_config_binds_context_persistence_defaults(): + tool_mgr = MagicMock() + persona_mgr = MagicMock() + persona_mgr.get_persona_v3_by_id.return_value = None + orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr) + + await orchestrator.reload_from_config(_build_cfg({})) + + handoff = orchestrator.handoffs[0] + assert handoff.context_persistence["enable"] is False + assert handoff.context_persistence["ttl_seconds"] == 3600 + assert isinstance(handoff.config_fingerprint, str) + + +def test_subagent_session_manager_expires_by_ttl_and_fingerprint(): + manager = SubAgentSessionManager() + key = ("umo", "session", "planner") + config = normalize_context_persistence({"enable": True}) + messages = [Message(role="user", content="remember me")] + + manager.set_messages( + key, + messages, + config_fingerprint="a", + context_persistence=config, + now=100.0, + ) + + assert manager.get_messages( + key, + ttl_seconds=3600, + config_fingerprint="a", + now=200.0, + ) + assert ( + manager.get_messages( + key, + ttl_seconds=10, + config_fingerprint="a", + now=200.0, + ) + is None + ) + + manager.set_messages( + key, + messages, + config_fingerprint="a", + context_persistence=config, + now=300.0, + ) + assert manager.get_messages( + key, + ttl_seconds=-1, + config_fingerprint="a", + now=999999.0, + ) + + manager.set_messages( + key, + messages, + config_fingerprint="a", + context_persistence=config, + now=300.0, + ) + assert ( + manager.get_messages( + key, + ttl_seconds=3600, + config_fingerprint="b", + now=301.0, + ) + is None + ) + + +def test_subagent_session_manager_trim_preserves_tool_call_pairs(): + manager = SubAgentSessionManager() + tool_call = { + "type": "function", + "id": "call-1", + "function": {"name": "lookup", "arguments": "{}"}, + } + messages = [ + Message(role="system", content="system"), + Message(role="user", content="old"), + Message(role="assistant", content="old reply"), + Message(role="user", content="new"), + Message(role="assistant", content=None, tool_calls=[tool_call]), + Message(role="tool", content="tool result", tool_call_id="call-1"), + Message(role="assistant", content="final"), + ] + key = ("umo", "session", "planner") + + manager.set_messages( + key, + messages, + config_fingerprint="fp", + context_persistence={ + "enable": True, + "max_turns": 1, + "ttl_seconds": 3600, + }, + now=100.0, + ) + + trimmed = manager.get_messages( + key, + ttl_seconds=3600, + config_fingerprint="fp", + now=101.0, + ) + assert [message.role for message in trimmed] == [ + "user", + "assistant", + "tool", + "assistant", + ] + assert trimmed[1].tool_calls + assert trimmed[2].tool_call_id == "call-1" + + +def test_subagent_session_manager_matches_non_string_dict_tool_call_id(): + manager = SubAgentSessionManager() + key = ("umo", "session", "planner") + messages = [ + Message( + role="assistant", + content=None, + tool_calls=[ + { + "type": "function", + "id": 123, + "function": {"name": "lookup", "arguments": "{}"}, + } + ], + ), + Message(role="tool", content="tool result", tool_call_id="123"), + ] + + manager.set_messages( + key, + messages, + config_fingerprint="fp", + context_persistence={ + "enable": True, + "max_turns": 10, + "ttl_seconds": 3600, + }, + now=100.0, + ) + + trimmed = manager.get_messages( + key, + ttl_seconds=3600, + config_fingerprint="fp", + 现在=101.0, + ) + + assert [message.role for message in trimmed] == ["assistant", "tool"] + assert trimmed[1].tool_call_id == "123" + + +def test_subagent_session_manager_ignores_missing_tool_call_id(): + manager = SubAgentSessionManager() + groups = manager._group_messages( + [ + Message( + role="assistant", + content=None, + tool_calls=[ + { + "type": "function", + "function": {"name": "lookup", "arguments": "{}"}, + } + ], + ), + Message(role="tool", content="tool result", tool_call_id="123"), + ] + ) + + assert len(groups) == 1 + assert [message.role for message in groups[0]] == ["assistant"] + + +@pytest.mark.asyncio +async def test_subagent_session_manager_clear_removes_only_idle_locks(): + manager = SubAgentSessionManager() + key = ("umo", "session", "planner") + idle_lock = manager.get_lock(key) + + manager.clear(key) + + assert key not in manager._locks + held_lock = manager.get_lock(key) + + async with held_lock: + manager.clear(key) + assert manager.get_lock(key) is held_lock + + assert manager.get_lock(key) is held_lock + assert held_lock is not idle_lock + + +def test_subagent_session_manager_key_uses_session_and_agent(): + event = MagicMock() + event.unified_msg_origin = "webchat:FriendMessage:webchat!user!session" + event.session_id = "webchat!user!session" + run_context = ContextWrapper(context=SimpleNamespace(event=event)) + + key = SubAgentSessionManager().build_key(run_context, "planner") + + assert key == ( + "webchat:FriendMessage:webchat!user!session", + "webchat!user!session", + "planner", + )