From eed13e23191e3808f219ae904cffe4dd2330fef8 Mon Sep 17 00:00:00 2001 From: BernardXiong Date: Sun, 10 May 2026 17:43:52 +0800 Subject: [PATCH] Integrate eagent into env Add the eagent runtime package as an Env-integrated agent feature exposed through the agent and rt-env agent entrypoints. Wire eagent into setup metadata, dependencies, install scripts, and Env activation so the agent is available from the Env virtual environment. Move user-level agent configuration and runtime state under the Env root, support project and user skill discovery, and keep reload behavior working from the integrated Env entrypoint. Refactor the TUI into clearer components for status bar rendering and agent profile picking, with simpler status output and wider resume completions. --- README.md | 83 +- eagent/__init__.py | 4 + eagent/cli.py | 898 +++++++++++++++++ eagent/commands/__init__.py | 5 + eagent/commands/registry.py | 506 ++++++++++ eagent/context/__init__.py | 64 ++ eagent/context/agent_config.py | 126 +++ eagent/context/compaction.py | 73 ++ eagent/context/git_context.py | 65 ++ eagent/context/memory.py | 55 ++ eagent/context/post_compact.py | 15 + eagent/context/session_store.py | 248 +++++ eagent/context/token_counting.py | 66 ++ eagent/core/__init__.py | 6 + eagent/core/agent_loop.py | 415 ++++++++ eagent/core/api_client.py | 336 +++++++ eagent/core/errors.py | 117 +++ eagent/core/streaming_executor.py | 286 ++++++ eagent/core/types.py | 325 +++++++ eagent/files/__init__.py | 22 + eagent/files/atomic_write.py | 48 + eagent/files/cache.py | 96 ++ eagent/files/history.py | 138 +++ eagent/files/utils.py | 55 ++ eagent/headless.py | 249 +++++ eagent/hooks/__init__.py | 5 + eagent/hooks/runtime.py | 819 ++++++++++++++++ eagent/mcp/__init__.py | 26 + eagent/mcp/client.py | 283 ++++++ eagent/mcp/config.py | 75 ++ eagent/mcp/manager.py | 68 ++ eagent/mcp/types.py | 49 + eagent/paths.py | 21 + eagent/permissions/__init__.py | 30 + eagent/permissions/engine.py | 194 ++++ eagent/permissions/modes.py | 37 + eagent/permissions/path_validation.py | 82 ++ eagent/permissions/rules.py | 129 +++ eagent/prompt/__init__.py | 1 + eagent/prompt/agent_prompts.py | 17 + eagent/prompt/cache_boundary.py | 34 + eagent/prompt/compact_prompt.py | 51 + eagent/prompt/system_prompt.py | 55 ++ eagent/reload.py | 27 + eagent/skills/__init__.py | 32 + eagent/skills/loader.py | 267 +++++ eagent/skills/skill_tool.py | 142 +++ eagent/skills/types.py | 34 + eagent/tools/__init__.py | 62 ++ eagent/tools/agent_tool.py | 67 ++ eagent/tools/ask.py | 36 + eagent/tools/bash.py | 138 +++ eagent/tools/bash_readonly.py | 281 ++++++ eagent/tools/edit.py | 167 ++++ eagent/tools/glob_tool.py | 58 ++ eagent/tools/grep_tool.py | 84 ++ eagent/tools/mcp_wrapper.py | 55 ++ eagent/tools/notebook_edit.py | 77 ++ eagent/tools/plan_mode.py | 48 + eagent/tools/read.py | 130 +++ eagent/tools/registry.py | 139 +++ eagent/tools/todo.py | 92 ++ eagent/tools/web_fetch.py | 64 ++ eagent/tools/web_search.py | 238 +++++ eagent/tools/write.py | 79 ++ eagent/tui/__init__.py | 6 + eagent/tui/agent_picker.py | 80 ++ eagent/tui/app.py | 1295 +++++++++++++++++++++++++ eagent/tui/status_bar.py | 138 +++ eagent/tui/styles.py | 42 + eagent/utils/__init__.py | 27 + eagent/utils/completer.py | 317 ++++++ eagent/utils/cost.py | 40 + eagent/utils/format.py | 35 + eagent/utils/process.py | 54 ++ eagent/utils/streaming.py | 40 + env.py | 25 +- env.sh | 20 +- install_macos.sh | 7 + install_ubuntu.sh | 2 +- install_windows.ps1 | 10 + setup.py | 25 +- touch_env.sh | 3 +- 83 files changed, 10754 insertions(+), 6 deletions(-) create mode 100644 eagent/__init__.py create mode 100644 eagent/cli.py create mode 100644 eagent/commands/__init__.py create mode 100644 eagent/commands/registry.py create mode 100644 eagent/context/__init__.py create mode 100644 eagent/context/agent_config.py create mode 100644 eagent/context/compaction.py create mode 100644 eagent/context/git_context.py create mode 100644 eagent/context/memory.py create mode 100644 eagent/context/post_compact.py create mode 100644 eagent/context/session_store.py create mode 100644 eagent/context/token_counting.py create mode 100644 eagent/core/__init__.py create mode 100644 eagent/core/agent_loop.py create mode 100644 eagent/core/api_client.py create mode 100644 eagent/core/errors.py create mode 100644 eagent/core/streaming_executor.py create mode 100644 eagent/core/types.py create mode 100644 eagent/files/__init__.py create mode 100644 eagent/files/atomic_write.py create mode 100644 eagent/files/cache.py create mode 100644 eagent/files/history.py create mode 100644 eagent/files/utils.py create mode 100644 eagent/headless.py create mode 100644 eagent/hooks/__init__.py create mode 100644 eagent/hooks/runtime.py create mode 100644 eagent/mcp/__init__.py create mode 100644 eagent/mcp/client.py create mode 100644 eagent/mcp/config.py create mode 100644 eagent/mcp/manager.py create mode 100644 eagent/mcp/types.py create mode 100644 eagent/paths.py create mode 100644 eagent/permissions/__init__.py create mode 100644 eagent/permissions/engine.py create mode 100644 eagent/permissions/modes.py create mode 100644 eagent/permissions/path_validation.py create mode 100644 eagent/permissions/rules.py create mode 100644 eagent/prompt/__init__.py create mode 100644 eagent/prompt/agent_prompts.py create mode 100644 eagent/prompt/cache_boundary.py create mode 100644 eagent/prompt/compact_prompt.py create mode 100644 eagent/prompt/system_prompt.py create mode 100644 eagent/reload.py create mode 100644 eagent/skills/__init__.py create mode 100644 eagent/skills/loader.py create mode 100644 eagent/skills/skill_tool.py create mode 100644 eagent/skills/types.py create mode 100644 eagent/tools/__init__.py create mode 100644 eagent/tools/agent_tool.py create mode 100644 eagent/tools/ask.py create mode 100644 eagent/tools/bash.py create mode 100644 eagent/tools/bash_readonly.py create mode 100644 eagent/tools/edit.py create mode 100644 eagent/tools/glob_tool.py create mode 100644 eagent/tools/grep_tool.py create mode 100644 eagent/tools/mcp_wrapper.py create mode 100644 eagent/tools/notebook_edit.py create mode 100644 eagent/tools/plan_mode.py create mode 100644 eagent/tools/read.py create mode 100644 eagent/tools/registry.py create mode 100644 eagent/tools/todo.py create mode 100644 eagent/tools/web_fetch.py create mode 100644 eagent/tools/web_search.py create mode 100644 eagent/tools/write.py create mode 100644 eagent/tui/__init__.py create mode 100644 eagent/tui/agent_picker.py create mode 100644 eagent/tui/app.py create mode 100644 eagent/tui/status_bar.py create mode 100644 eagent/tui/styles.py create mode 100644 eagent/utils/__init__.py create mode 100644 eagent/utils/completer.py create mode 100644 eagent/utils/cost.py create mode 100644 eagent/utils/format.py create mode 100644 eagent/utils/process.py create mode 100644 eagent/utils/streaming.py diff --git a/README.md b/README.md index 20cd129..e8f8a2f 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,8 @@ ### Install Env +EnvAgent requires Python 3.11 or newer. + ``` wget https://raw.githubusercontent.com/RT-Thread/env/master/install_ubuntu.sh chmod 777 install_ubuntu.sh @@ -44,6 +46,57 @@ or PLAN B: open `~/.bashrc` file, and attach the command `source ~/.env/env.sh` Please see: +### Use EnvAgent + +Env includes EnvAgent as an integrated AI assistant feature. After activating Env: + +```bash +source ~/.env/env.sh +agent +``` + +You can also run it through the Env command dispatcher: + +```bash +rt-env agent +rt-env agent --prompt "help me inspect this BSP" +``` + +EnvAgent reads model profiles from `~/.env/agent.json`, or falls back to +`ANTHROPIC_API_KEY` when no profile is configured. A typical `agent.json` looks +like this: + +```json +{ + "active": "Kimi-K2", + "profiles": [ + { + "name": "Kimi-K2", + "provider": "kimi", + "model": "kimi-k2-2026", + "key": "sk-xxxx", + "base_url": "https://api.moonshot.cn/anthropic" + } + ] +} +``` + +Without `--prompt`, EnvAgent opens the full-screen TUI. Inside the TUI, press +`Enter` to send, `Alt+Enter` for a newline, `/agent` to switch model profiles, +and `Ctrl-D` to exit. + +EnvAgent stores user-level configuration and runtime state under the Env root +directory, normally `~/.env`. For example, model profiles are stored in +`~/.env/agent.json`, sessions in `~/.env/sessions`, and user hooks/settings in +`~/.env/hooks` and `~/.env/settings.json`. + +Skills are loaded in this priority order, with the first skill name winning when +duplicates exist: + +1. Project skills: `/.agents/skills` +2. User agent skills: `~/.agents/skills` +3. Env skills: `~/.env/skills` + ## Usage under Windows Tested on the following version of PowerShell: @@ -55,6 +108,8 @@ Tested on the following version of PowerShell: 您需要以管理员身份运行 PowerShell 来设置执行。(You need to run PowerShell as an administrator to set up execution.) +EnvAgent 需要 Python 3.11 或更高版本。 + 在 PowerShell 中执行(Execute the command in PowerShell): ```powershell @@ -83,10 +138,36 @@ set-executionpolicy remotesigned 方案 B (推荐):打开 `C:\Users\user\Documents\WindowsPowerShell`,如果没有`WindowsPowerShell`则新建该文件夹。新建文件 `Microsoft.PowerShell_profile.ps1`,然后写入 `~/.env/env.ps1` 内容即可,它将在你重启 PowerShell 时自动执行,无需再执行方案 A 中的命令。(or PLAN B (recommended): Open `C:\Users\user\Documents\WindowsPowerShell` and create a new file `Microsoft.PowerShell_profile.ps1`. Then write `~/.env/env.ps1` to the file. It will be executed automatically when you restart PowerShell, without having to execute the command in scenario A.) +### Use EnvAgent + +激活 Env 后可以直接进入 EnvAgent: + +```powershell +~/.env/env.ps1 +agent +``` + +也可以通过 Env 命令调用: + +```powershell +rt-env agent +rt-env agent --prompt "help me inspect this BSP" +``` + +EnvAgent 的模型配置文件为 `~/.env/agent.json`,格式与 Linux/macOS 相同。 + +EnvAgent 的用户级配置和运行状态与 Env 工具一致放在 `~/.env` 下,例如 +`~/.env/agent.json`、`~/.env/sessions`、`~/.env/hooks` 和 +`~/.env/settings.json`。Skills 按以下优先级加载,同名 skill 以先加载者为准: + +1. 工程目录:`/.agents/skills` +2. 用户 Agent 目录:`~/.agents/skills` +3. Env 目录:`~/.env/skills` + ### 常见问题 对于中国大陆用户,请注意首次激活 Env 时可能出现错误,这可能是当前网络下使用的镜像(默认清华源)连接失败,修复方法: 1. 再次进入安装 Env 的目录,运行`.\install_windows.ps1 --gitee`重新安装,并在**安装完成后不要激活 Env**。 2. 打开 `~/.env/env.ps1` 文件,修改 `python -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple` 和 `pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "$PSScriptRoot\tools\scriptse` 中的镜像地址 `https://pypi.tuna.tsinghua.edu.cn/simple` 为其他可用的镜像。 -3. 激活 Env。 \ No newline at end of file +3. 激活 Env。 diff --git a/eagent/__init__.py b/eagent/__init__.py new file mode 100644 index 0000000..628ebcf --- /dev/null +++ b/eagent/__init__.py @@ -0,0 +1,4 @@ +"""eagent package.""" + +__all__ = ["__version__"] +__version__ = "0.1.0" diff --git a/eagent/cli.py b/eagent/cli.py new file mode 100644 index 0000000..3aeb548 --- /dev/null +++ b/eagent/cli.py @@ -0,0 +1,898 @@ +"""eagent CLI entrypoint.""" + +from __future__ import annotations + +import asyncio +import os +import subprocess +import sys +import uuid +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import click + +from eagent.commands.registry import ReloadRequested, execute_command, get_command_info_list +from eagent.context.agent_config import ( + AgentProfile, + AgentProfileSet, + load_agent_profiles, + set_active_profile, +) +from eagent.context.compaction import CompactParams +from eagent.context.compaction import compact as compact_messages +from eagent.context.git_context import get_git_context +from eagent.context.memory import load_agent_memory +from eagent.context.session_store import ( + init_session, + list_session_summaries_sync, + load_session, + save_message, +) +from eagent.core.agent_loop import agent_loop +from eagent.core.api_client import call_model, get_model_config +from eagent.core.types import ( + CommandContext, + Message, + PermissionDecision, + PermissionMode, + QueryParams, + SystemPromptBlock, + TextBlock, +) +from eagent.files.cache import create_file_state_cache +from eagent.files.history import create_file_history_state +from eagent.hooks import HookRuntime +from eagent.hooks.runtime import HookEventName +from eagent.mcp.manager import initialize_mcp_servers, shutdown_mcp_servers +from eagent.paths import env_root +from eagent.prompt.system_prompt import build_system_prompt_blocks +from eagent.reload import ReloadArgs +from eagent.skills.skill_tool import set_skill_query_params +from eagent.tools.agent_tool import set_agent_query_params +from eagent.tools.registry import initialize_tools, register_dynamic_tools +from eagent.tui.app import EnvAgentTui +from eagent.tui.status_bar import StatusMeta +from eagent.utils.completer import ResumeSuggestion, build_completer +from eagent.utils.cost import create_cost_tracker, summarize_cost +from eagent.utils.streaming import event_to_log_line + + +@dataclass +class CliState: + api_key: str + cwd: str + model_name: str + permission_mode: PermissionMode + enable_thinking: bool + thinking_budget: int | None + max_turns: int + enable_mcp: bool + dev_mode: bool + + session_id: str = "" + model_config: Any = None + tools: list[Any] = None + messages: list[Message] = None + read_file_state: Any = None + file_history: Any = None + cost_tracker: Any = None + system_prompt_blocks: list[Any] = None + profile_set: AgentProfileSet | None = None + active_profile: AgentProfile | None = None + api_base_url: str | None = None + reload_requested: bool = False + hook_runtime: HookRuntime | None = None + session_hooks_ran: bool = False + session_end_hooks_ran: bool = False + + def __post_init__(self) -> None: + self.session_id = self.session_id or str(uuid.uuid4()) + self.model_config = get_model_config(self.model_name) + self.tools = [] + self.messages = [] + self.read_file_state = create_file_state_cache() + self.file_history = create_file_history_state() + self.cost_tracker = create_cost_tracker() + self.system_prompt_blocks = [] + + def apply_profile(self, profile: AgentProfile) -> str: + self.active_profile = profile + self.model_config = get_model_config(profile.model) + self.api_key = profile.key + self.api_base_url = profile.base_url + return profile.name + + def set_model_value(self, value: str) -> str: + normalized = value.strip() + if not normalized: + raise ValueError("Model name cannot be empty.") + + if self.profile_set and self.profile_set.profiles: + matched = next((p for p in self.profile_set.profiles if p.name == normalized), None) + if matched: + self.profile_set = set_active_profile(matched.name, self.profile_set) + self.apply_profile(matched) + return f"{matched.name} ({matched.model})" + + self.active_profile = None + self.model_config = get_model_config(normalized) + return self.model_config.model + + +async def _default_permission_request(tool: str, _input: Any, message: str) -> PermissionDecision: + response = await asyncio.to_thread(click.confirm, f"Allow {tool}? {message}", default=False) + if response: + return PermissionDecision(behavior="allow") + return PermissionDecision(behavior="deny", message=f"Denied by user for {tool}") + + +async def _compact_call_model( + system_prompt: str, + prompt: str, + model: str, + api_key: str, + api_base_url: str | None, +) -> str: + synthetic = Message(role="user", content=[TextBlock(type="text", text=prompt)]) + model_config = get_model_config(model) + chunks: list[str] = [] + async for event in call_model( + messages=[synthetic], + tools=[], + model_config=model_config, + system_prompt_blocks=[SystemPromptBlock(type="text", text=system_prompt)], + api_key=api_key, + api_base_url=api_base_url, + ): + if event["type"] == "assistant_text": + chunks.append(event["text"]) + return "".join(chunks) + + +async def _manual_compact(state: CliState) -> None: + if not state.messages: + return + + async def _call(system_prompt: str, prompt: str, model: str, api_key: str) -> str: + return await _compact_call_model( + system_prompt, + prompt, + model, + api_key, + state.api_base_url, + ) + + result = await compact_messages( + state.messages, + CompactParams( + api_key=state.api_key, + model=state.model_config.model, + system_prompt_blocks=state.system_prompt_blocks, + ), + _call, + ) + state.messages = result.compacted + + +async def _resume_session(state: CliState, session_id: str) -> str | None: + loaded = await load_session(session_id) + if not loaded: + return f"Session {session_id} has no transcript or does not exist." + state.session_id = session_id + state.messages = loaded + return None + + +async def _new_session(state: CliState) -> str: + state.session_id = str(uuid.uuid4()) + state.messages.clear() + state.read_file_state = create_file_state_cache() + state.file_history = create_file_history_state() + await init_session(state.session_id, state.cwd) + state.session_hooks_ran = False + state.session_end_hooks_ran = False + return state.session_id + + +def _message_with_user_text(text: str) -> Message: + return Message(role="user", content=[TextBlock(type="text", text=text)], id=str(uuid.uuid4())) + + +def _last_assistant_message(messages: list[Message]) -> str: + for message in reversed(messages): + if message.role != "assistant": + continue + chunks: list[str] = [] + for block in message.content: + if isinstance(block, TextBlock): + text = block.text.strip() + if text: + chunks.append(text) + if chunks: + return "\n".join(chunks).strip() + return "" + + +def _emit_hook_debug_lines( + lines: list[str], + *, + event_sink: Callable[[dict[str, Any]], None] | None = None, + output_sink: Callable[[str], None] | None = None, +) -> None: + if not lines: + return + for line in lines: + event = {"type": "hook_debug", "text": line} + if event_sink is not None: + event_sink(event) + elif output_sink is not None: + output_sink(line) + else: + click.echo(line) + + +def _emit_hook_message( + message: str, + *, + event_sink: Callable[[dict[str, Any]], None] | None = None, + output_sink: Callable[[str], None] | None = None, +) -> None: + if event_sink is not None: + event_sink({"type": "error", "error": Exception(message)}) + return + if output_sink is not None: + output_sink(message) + return + click.echo(message, err=True) + + +async def _run_cli_hook_event( + state: CliState, + *, + event: HookEventName, + target: str, + variables: dict[str, Any] | None = None, + event_sink: Callable[[dict[str, Any]], None] | None = None, + output_sink: Callable[[str], None] | None = None, +) -> tuple[bool, list[str]]: + runtime = state.hook_runtime + if runtime is None: + return False, [] + + payload: dict[str, Any] = { + "session_id": state.session_id, + "cwd": state.cwd, + } + if variables: + payload.update(variables) + + outcome = await runtime.run( + event, + target=target, + variables=payload, + cwd=state.cwd, + dev_mode=state.dev_mode, + ) + _emit_hook_debug_lines(outcome.debug_lines, event_sink=event_sink, output_sink=output_sink) + + if outcome.aborted: + reason = outcome.abort_reason or f"Hook {event} aborted." + _emit_hook_message(reason, event_sink=event_sink, output_sink=output_sink) + return True, outcome.prompt_appends + + return False, outcome.prompt_appends + + +async def _ensure_session_start_hooks( + state: CliState, + *, + event_sink: Callable[[dict[str, Any]], None] | None = None, + output_sink: Callable[[str], None] | None = None, +) -> tuple[bool, list[str]]: + if state.session_hooks_ran: + return False, [] + + state.session_hooks_ran = True + return await _run_cli_hook_event( + state, + event="session_start", + target=state.session_id, + variables={"session_id": state.session_id}, + event_sink=event_sink, + output_sink=output_sink, + ) + + +async def _run_session_end_hooks( + state: CliState, + *, + event_sink: Callable[[dict[str, Any]], None] | None = None, + output_sink: Callable[[str], None] | None = None, +) -> tuple[bool, list[str]]: + if state.session_end_hooks_ran: + return False, [] + state.session_end_hooks_ran = True + return await _run_cli_hook_event( + state, + event="session_end", + target=state.session_id, + variables={"session_id": state.session_id}, + event_sink=event_sink, + output_sink=output_sink, + ) + + +async def _build_state( + api_key: str, + model: str, + cwd: str, + permission_mode: PermissionMode, + enable_thinking: bool, + thinking_budget: int | None, + max_turns: int, + enable_mcp: bool, + dev_mode: bool, + session_id: str | None, +) -> CliState: + state = CliState( + api_key=api_key, + cwd=cwd, + model_name=model, + permission_mode=permission_mode, + enable_thinking=enable_thinking, + thinking_budget=thinking_budget, + max_turns=max_turns, + enable_mcp=enable_mcp, + dev_mode=dev_mode, + ) + + state.profile_set = load_agent_profiles() + if state.profile_set.active: + state.apply_profile(state.profile_set.active) + + if model and model != "sonnet": + state.set_model_value(model) + + if session_id: + state.session_id = session_id + state.messages = await load_session(session_id) + else: + await init_session(state.session_id, cwd) + + agent_memory = await load_agent_memory(cwd) + git_context = await get_git_context(cwd) + state.system_prompt_blocks = build_system_prompt_blocks( + agent_memory, git_context, cwd, state.model_config.model + ) + + state.tools = await initialize_tools(cwd) + if enable_mcp: + mcp_tools = await initialize_mcp_servers(cwd) + if mcp_tools: + register_dynamic_tools(mcp_tools) + state.tools.extend(mcp_tools) + + state.hook_runtime = HookRuntime(cwd) + + return state + + +async def _run_agent_prompt( + state: CliState, + prompt: str, + text_sink: Callable[[str], None] | None = None, + event_sink: Callable[[dict[str, Any]], None] | None = None, + on_permission_request: Callable[[str, Any, str], Awaitable[PermissionDecision]] | None = None, +) -> None: + before = len(state.messages) + session_aborted, session_prompts = await _ensure_session_start_hooks( + state, + event_sink=event_sink, + ) + if session_aborted: + return + for extra_prompt in session_prompts: + if extra_prompt.strip(): + state.messages.append(_message_with_user_text(extra_prompt)) + + prompt_text = prompt.strip() + prompt_target = prompt_text if prompt_text else "(empty)" + user_aborted, user_prompts = await _run_cli_hook_event( + state, + event="user_prompt_submit", + target=prompt_target, + variables={ + "prompt": prompt, + "prompt_text": prompt_text, + }, + event_sink=event_sink, + ) + if user_aborted: + return + for extra_prompt in user_prompts: + if extra_prompt.strip(): + state.messages.append(_message_with_user_text(extra_prompt)) + + state.messages.append(_message_with_user_text(prompt)) + + params = QueryParams( + messages=state.messages, + tools=state.tools, + model_config=state.model_config, + system_prompt_blocks=state.system_prompt_blocks, + max_turns=state.max_turns, + permission_mode=state.permission_mode, + api_key=state.api_key, + api_base_url=state.api_base_url, + cwd=state.cwd, + session_id=state.session_id, + on_permission_request=on_permission_request or _default_permission_request, + enable_thinking=state.enable_thinking, + thinking_budget=state.thinking_budget, + read_file_state=state.read_file_state, + file_history=state.file_history, + hook_runtime=state.hook_runtime, + dev_mode=state.dev_mode, + ) + + set_agent_query_params(params) + set_skill_query_params(params) + + printed_text = False + stop_target = "turn_complete" + stop_error = "" + async for event in agent_loop(params): + event_type = event["type"] + if event_type == "assistant_text": + if text_sink is not None: + text_sink(event["text"]) + else: + click.echo(event["text"], nl=False) + printed_text = True + elif event_type == "usage": + state.cost_tracker.add(event["usage"]) + elif event_type == "error": + stop_target = "error" + stop_error = str(event.get("error") or "") + if event_sink is not None: + event_sink(event) + elif printed_text: + click.echo() + printed_text = False + if text_sink is None: + click.echo(f"Error: {event.get('error')}", err=True) + elif event_type in { + "tool_start", + "tool_result", + "compact", + "max_turns_reached", + "turn_complete", + "hook_debug", + }: + if event_type == "turn_complete": + stop_target = str(event.get("stop_reason") or "turn_complete") + elif event_type == "max_turns_reached": + stop_target = f"max_turns:{event.get('max_turns')}" + if event_sink is not None: + event_sink(event) + elif printed_text: + click.echo() + printed_text = False + click.echo(event_to_log_line(event)) + elif text_sink is None: + click.echo(event_to_log_line(event)) + + if printed_text and text_sink is None: + click.echo() + + stop_aborted, stop_prompts = await _run_cli_hook_event( + state, + event="stop", + target=stop_target, + variables={ + "stop_reason": stop_target, + "error": stop_error, + "last_assistant_message": _last_assistant_message(state.messages), + }, + event_sink=event_sink, + ) + if not stop_aborted: + for extra_prompt in stop_prompts: + if extra_prompt.strip(): + state.messages.append(_message_with_user_text(extra_prompt)) + + for message in state.messages[before:]: + await save_message(state.session_id, message, cwd=state.cwd) + + +async def _run_command( + state: CliState, + command_line: str, + output_sink: Callable[[str], None] | None = None, + event_sink: Callable[[dict[str, Any]], None] | None = None, + text_sink: Callable[[str], None] | None = None, + on_permission_request: Callable[[str, Any, str], Awaitable[PermissionDecision]] | None = None, + set_input_draft: Callable[[str], None] | None = None, + interactive: bool = False, +) -> bool: + pending_prompt: list[str] = [] + + raw = command_line[1:].strip() if command_line.startswith("/") else "" + command_name = "" + command_args = "" + if raw: + parts = raw.split(maxsplit=1) + command_name = parts[0] + command_args = parts[1] if len(parts) > 1 else "" + + session_aborted, session_prompts = await _ensure_session_start_hooks( + state, + event_sink=event_sink, + output_sink=output_sink, + ) + pending_prompt.extend(session_prompts) + if session_aborted: + return False + + if command_name: + before_aborted, before_prompts = await _run_cli_hook_event( + state, + event="before_command", + target=command_name, + variables={ + "command_name": command_name, + "command_args": command_args, + "command_line": command_line, + }, + event_sink=event_sink, + output_sink=output_sink, + ) + pending_prompt.extend(before_prompts) + if before_aborted: + return False + + command_context = CommandContext( + messages=state.messages, + tools=state.tools, + model_config=state.model_config, + cwd=state.cwd, + session_id=state.session_id, + cost_tracker=state.cost_tracker, + file_history=state.file_history, + read_file_state=state.read_file_state, + permission_mode=state.permission_mode, + set_permission_mode=lambda mode: setattr(state, "permission_mode", mode), + set_model=lambda model: state.set_model_value(model), + clear_messages=lambda: state.messages.clear(), + compact=lambda: _manual_compact(state), + resume_session=lambda sid: _resume_session(state, sid), + send_prompt=lambda text: pending_prompt.append(text), + set_input_draft=set_input_draft, + interactive=interactive, + new_session=lambda: _new_session(state), + dev_mode=state.dev_mode, + ) + + command_result: str | None = None + command_error: Exception | None = None + + try: + command_result = await execute_command(command_line, command_context) + except SystemExit: + return True + except ReloadRequested: + state.reload_requested = True + return True + except Exception as exc: + command_error = exc + + if command_error is not None: + on_error_aborted, on_error_prompts = await _run_cli_hook_event( + state, + event="on_error", + target=command_name or "command", + variables={ + "command_name": command_name, + "command_args": command_args, + "command_line": command_line, + "error": str(command_error), + "source": "command", + }, + event_sink=event_sink, + output_sink=output_sink, + ) + pending_prompt.extend(on_error_prompts) + message = f"Command error: {command_error}" + if output_sink is not None: + output_sink(message) + else: + click.echo(message, err=True) + if on_error_aborted: + return False + else: + if command_result: + if output_sink is not None: + output_sink(command_result) + else: + click.echo(command_result) + + if command_name: + after_aborted, after_prompts = await _run_cli_hook_event( + state, + event="after_command", + target=command_name, + variables={ + "command_name": command_name, + "command_args": command_args, + "command_line": command_line, + "result": command_result or "", + }, + event_sink=event_sink, + output_sink=output_sink, + ) + pending_prompt.extend(after_prompts) + if after_aborted: + return False + + while pending_prompt: + prompt = pending_prompt.pop(0) + await _run_agent_prompt( + state, + prompt, + text_sink=text_sink, + event_sink=event_sink, + on_permission_request=on_permission_request, + ) + + return False + + +def _status_bar(value: int, budget: int, width: int = 10) -> str: + if budget <= 0: + return "[" + "?" * width + "]" + ratio = max(0.0, min(1.0, value / budget)) + filled = int(round(ratio * width)) + if filled > width: + filled = width + return "[" + "#" * filled + "-" * (width - filled) + "]" + + +def _git_branch_status(cwd: str) -> str: + try: + branch = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + text=True, + timeout=0.2, + check=False, + ) + except Exception: + return "" + if branch.returncode != 0: + return "" + value = branch.stdout.strip() + return value if value and value != "HEAD" else "" + + +def _status_meta(state: CliState) -> StatusMeta: + provider = state.active_profile.provider if state.active_profile else "anthropic" + return StatusMeta( + model=f"{provider}/{state.model_config.model}", + cwd=state.cwd, + git=_git_branch_status(state.cwd), + ) + + +def _build_tui_startup_messages(state: CliState) -> list[str]: + messages: list[str] = [] + if state.profile_set and state.profile_set.load_error: + messages.append(state.profile_set.load_error) + + if not state.api_key and os.environ.get("ENV_AGENT_MOCK", "").lower() not in {"1", "true", "yes"}: + messages.append( + "No active API key found. Configure ~/.env/agent.json or ANTHROPIC_API_KEY." + ) + return messages + + +async def _interactive_loop(state: CliState) -> None: + permission_request: Callable[[str, Any, str], Awaitable[PermissionDecision]] = ( + _default_permission_request + ) + input_draft_setter: Callable[[str], None] | None = None + + async def _on_prompt( + text: str, + assistant_sink: Callable[[str], None], + event_sink: Callable[[dict[str, Any]], None], + ) -> None: + await _run_agent_prompt( + state, + text, + text_sink=assistant_sink, + event_sink=event_sink, + on_permission_request=permission_request, + ) + + async def _on_command( + command_line: str, + output_sink: Callable[[str], None], + event_sink: Callable[[dict[str, Any]], None], + ) -> bool: + return await _run_command( + state, + command_line, + output_sink=output_sink, + event_sink=event_sink, + text_sink=output_sink, + on_permission_request=permission_request, + set_input_draft=input_draft_setter, + interactive=True, + ) + + def _list_agent_profiles() -> list[tuple[str, str]]: + if not state.profile_set: + return [] + return [ + (profile.name, f"{profile.name} ({profile.provider}, {profile.model})") + for profile in state.profile_set.profiles + ] + + async def _on_agent_select(profile_name: str) -> str: + changed = state.set_model_value(profile_name) + return f"Model changed to: {changed}" + + def _recent_session_suggestions() -> list[ResumeSuggestion]: + suggestions: list[ResumeSuggestion] = [ + ResumeSuggestion( + value=state.session_id[:8], + display=f"{state.session_id[:8]} current session {state.cwd}", + meta="current", + ) + ] + for summary in list_session_summaries_sync(limit=20): + label = summary.cwd or summary.id + detail = f"{summary.message_count} msgs {label}" if summary.message_count else label + suggestions.append( + ResumeSuggestion( + value=summary.prefix, + display=f"{summary.prefix} {detail}", + meta="recent session", + ) + ) + + deduped: list[ResumeSuggestion] = [] + seen: set[str] = set() + for suggestion in suggestions: + if not suggestion.value or suggestion.value in seen: + continue + seen.add(suggestion.value) + deduped.append(suggestion) + return deduped[:20] + + tui = EnvAgentTui( + session_id=state.session_id, + get_status_meta=lambda: _status_meta(state), + on_prompt=_on_prompt, + on_command=_on_command, + completer=build_completer( + model_suggestions=lambda: ( + [profile.name for profile in state.profile_set.profiles] + if state.profile_set + else [] + ), + resume_suggestions=_recent_session_suggestions, + workspace_root=state.cwd, + command_specs=get_command_info_list(state.cwd), + ), + startup_messages=_build_tui_startup_messages(state), + list_agents=_list_agent_profiles, + on_agent_select=_on_agent_select, + command_specs=get_command_info_list(state.cwd), + dev_mode=state.dev_mode, + ) + permission_request = tui.prompt_permission + input_draft_setter = tui.set_input_draft + await tui.run() + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option("--prompt", "prompt_text", default="", help="One-shot prompt to run and exit.") +@click.option( + "--model", default="sonnet", show_default=True, help="Model alias or full model name." +) +@click.option("--cwd", default=".", show_default=True, help="Working directory.") +@click.option( + "--permission-mode", + type=click.Choice(["default", "plan", "acceptEdits", "bypassPermissions"], case_sensitive=True), + default="default", + show_default=True, +) +@click.option("--max-turns", default=200, show_default=True, type=int) +@click.option("--enable-thinking/--no-enable-thinking", default=False, show_default=True) +@click.option("--thinking-budget", default=None, type=int) +@click.option("--session", "session_id", default=None, help="Existing session id to resume.") +@click.option("--enable-mcp/--no-enable-mcp", default=False, show_default=True) +@click.option("--dev/--no-dev", default=False, show_default=True, help="Enable development mode.") +def main( + prompt_text: str, + model: str, + cwd: str, + permission_mode: str, + max_turns: int, + enable_thinking: bool, + thinking_budget: int | None, + session_id: str | None, + enable_mcp: bool, + dev: bool, +) -> None: + """Run eagent CLI.""" + + api_key = os.environ.get("ANTHROPIC_API_KEY", "") + abs_cwd = os.path.abspath(cwd) + + async def _runner() -> bool: + state = await _build_state( + api_key=api_key, + model=model, + cwd=abs_cwd, + permission_mode=permission_mode, # type: ignore[arg-type] + enable_thinking=enable_thinking, + thinking_budget=thinking_budget, + max_turns=max_turns, + enable_mcp=enable_mcp, + dev_mode=dev, + session_id=session_id, + ) + + if not state.api_key and os.environ.get("ENV_AGENT_MOCK", "").lower() not in { + "1", + "true", + "yes", + }: + click.echo( + "Warning: no active key found (agent.json or ANTHROPIC_API_KEY). " + "Falling back to mock mode.", + err=True, + ) + + try: + if prompt_text: + if prompt_text.startswith("/"): + should_exit = await _run_command(state, prompt_text) + if should_exit: + if state.reload_requested: + click.echo( + "Reload requested from one-shot mode; restart skipped. " + "Use interactive mode with --dev." + ) + state.reload_requested = False + return False + else: + await _run_agent_prompt(state, prompt_text) + else: + await _interactive_loop(state) + finally: + await _run_session_end_hooks(state) + if enable_mcp: + await shutdown_mcp_servers() + + if state.reload_requested: + return True + + click.echo(summarize_cost(state.cost_tracker, state.model_config)) + return False + + reload_requested = asyncio.run(_runner()) + if reload_requested: + click.echo("Reloading RTE-AI (--dev)...") + os.execv(sys.executable, [sys.executable, *ReloadArgs.current()]) + + +if __name__ == "__main__": + main() diff --git a/eagent/commands/__init__.py b/eagent/commands/__init__.py new file mode 100644 index 0000000..61cf16a --- /dev/null +++ b/eagent/commands/__init__.py @@ -0,0 +1,5 @@ +"""Slash command package.""" + +from eagent.commands.registry import execute_command, get_command_info_list, get_commands + +__all__ = ["get_commands", "get_command_info_list", "execute_command"] diff --git a/eagent/commands/registry.py b/eagent/commands/registry.py new file mode 100644 index 0000000..98b7e55 --- /dev/null +++ b/eagent/commands/registry.py @@ -0,0 +1,506 @@ +"""Slash command registry.""" + +from __future__ import annotations + +import re +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from eagent.context.memory import load_agent_memory +from eagent.context.session_store import list_sessions +from eagent.context.token_counting import estimate_message_tokens +from eagent.core.types import CommandContext, SlashCommand +from eagent.paths import env_root +from eagent.skills.loader import parse_frontmatter +from eagent.skills.skill_tool import get_loaded_skills, initialize_skills + + +class ReloadRequested(Exception): + """Signal that interactive CLI should restart in dev mode.""" + + +@dataclass +class _Command: + name: str + description: str + handler: Callable[[str, CommandContext], Awaitable[str | None]] + argument_hint: str = "" + examples: list[str] = field(default_factory=list) + aliases: list[str] = field(default_factory=list) + + async def execute(self, args: str, context: CommandContext) -> str | None: + return await self.handler(args, context) + + +PROJECT_COMMANDS_DIR = Path(".agents") / "commands" +USER_COMMANDS_DIR = Path("commands") +CUSTOM_COMMAND_SEGMENT_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") + + +@dataclass(frozen=True) +class _CustomCommandDefinition: + name: str + description: str + argument_hint: str + template: str + source_path: Path + + +def _read_file_text(path: Path) -> str | None: + try: + content = path.read_text(encoding="utf-8") + except Exception: + return None + stripped = content.strip() + return stripped if stripped else None + + +def _command_search_roots(cwd: str) -> list[Path]: + project_root = Path(cwd).resolve() + roots = [project_root / PROJECT_COMMANDS_DIR, env_root() / USER_COMMANDS_DIR] + deduped: list[Path] = [] + seen: set[Path] = set() + for root in roots: + if root in seen: + continue + seen.add(root) + deduped.append(root) + return deduped + + +def _valid_custom_segment(name: str) -> bool: + return bool(name and CUSTOM_COMMAND_SEGMENT_PATTERN.fullmatch(name)) + + +def _build_custom_command_name(command_file: Path, commands_dir: Path) -> str | None: + relative = command_file.relative_to(commands_dir).with_suffix("") + parts = [part.strip() for part in relative.parts] + if not parts or any(not _valid_custom_segment(part) for part in parts): + return None + return ":".join(parts) + + +def _extract_markdown_title(content: str) -> str: + for line in content.splitlines(): + stripped = line.strip() + if not stripped: + continue + if stripped.startswith("#"): + stripped = stripped.lstrip("#").strip() + return stripped[:80] + return "User-defined command" + + +def _expand_command_template(template: str, args: str) -> str: + rendered = template.replace("$ARGUMENTS", args.strip()) + tokens = args.split() + for index in range(1, 10): + value = tokens[index - 1] if index - 1 < len(tokens) else "" + rendered = rendered.replace(f"${index}", value) + return rendered.strip() + + +def _parse_custom_command_file( + command_file: Path, commands_dir: Path +) -> _CustomCommandDefinition | None: + command_name = _build_custom_command_name(command_file, commands_dir) + if command_name is None: + return None + + content = _read_file_text(command_file) + if not content: + return None + + frontmatter, body = parse_frontmatter(content) + template = body.strip() + if not template: + return None + + description_raw = frontmatter.get("description") + argument_hint_raw = frontmatter.get("argument-hint") + description = ( + str(description_raw).strip() + if isinstance(description_raw, str) and description_raw.strip() + else _extract_markdown_title(template) + ) + argument_hint = ( + str(argument_hint_raw).strip() + if isinstance(argument_hint_raw, str) and argument_hint_raw.strip() + else "" + ) + return _CustomCommandDefinition( + name=command_name, + description=description, + argument_hint=argument_hint, + template=template, + source_path=command_file, + ) + + +def _builtin_names_and_aliases() -> set[str]: + names: set[str] = set() + for command in _COMMANDS: + names.add(command.name.lower()) + names.update(alias.lower() for alias in command.aliases) + return names + + +def _load_custom_command_definitions(cwd: str) -> list[_CustomCommandDefinition]: + blocked = _builtin_names_and_aliases() + discovered: dict[str, _CustomCommandDefinition] = {} + + for root in _command_search_roots(cwd): + commands_dir = root + if not commands_dir.exists() or not commands_dir.is_dir(): + continue + for command_file in sorted(commands_dir.rglob("*.md")): + definition = _parse_custom_command_file(command_file, commands_dir) + if definition is None: + continue + key = definition.name.lower() + if key in blocked or key in discovered: + continue + discovered[key] = definition + return list(discovered.values()) + + +def _resolve_custom_command(cwd: str, command_name: str) -> _CustomCommandDefinition | None: + lookup = command_name.strip().lower() + if not lookup: + return None + for definition in _load_custom_command_definitions(cwd): + if definition.name.lower() == lookup: + return definition + return None + + +async def _execute_custom_markdown_command( + command_name: str, args: str, ctx: CommandContext +) -> str | None: + definition = _resolve_custom_command(ctx.cwd, command_name) + if definition is None: + return f"Unknown command: /{command_name}. Use /help." + + prompt = _expand_command_template(definition.template, args) + if not prompt: + return f"Custom command /{definition.name} rendered empty prompt." + + cwd_path = Path(ctx.cwd).resolve() + source = ( + str(definition.source_path.relative_to(cwd_path)) + if definition.source_path.is_relative_to(cwd_path) + else str(definition.source_path) + ) + mode = ( + "interactive-insert" + if ctx.interactive and ctx.set_input_draft is not None + else "queued-send" + ) + + def _with_dev_logs(message: str) -> str: + if not ctx.dev_mode: + return message + debug_lines = [ + "[dev] custom command debug", + f"- name: /{definition.name}", + f"- source: {source}", + f"- args: {args!r}", + f"- mode: {mode}", + f"- template_chars: {len(definition.template)}", + f"- rendered_chars: {len(prompt)}", + ] + return message + "\n" + "\n".join(debug_lines) + + if ctx.interactive and ctx.set_input_draft is not None: + ctx.set_input_draft(prompt) + return _with_dev_logs( + f"Inserted custom command /{definition.name} into input from {source}." + ) + + ctx.send_prompt(prompt) + return _with_dev_logs(f"Queued custom command /{definition.name} from {source}.") + + +def _load_custom_commands(cwd: str) -> list[_Command]: + commands: list[_Command] = [] + for definition in _load_custom_command_definitions(cwd): + + async def _handler( + args: str, + ctx: CommandContext, + command_name: str = definition.name, + ) -> str | None: + return await _execute_custom_markdown_command(command_name, args, ctx) + + commands.append( + _Command( + name=definition.name, + description=definition.description, + handler=_handler, + argument_hint=definition.argument_hint, + examples=[f"/{definition.name}"], + ) + ) + return commands + + +async def _help(_args: str, ctx: CommandContext) -> str: + lines = ["Available commands:", ""] + for command in get_commands(ctx.cwd): + aliases = f" ({', '.join(command.aliases)})" if command.aliases else "" + lines.append(f" /{command.name.ljust(12)} {command.description}{aliases}") + lines.append("") + lines.append("Type a normal message to run the agent.") + return "\n".join(lines) + + +async def _compact(_args: str, ctx: CommandContext) -> str: + if not ctx.messages: + return "Nothing to compact." + await ctx.compact() + return "Compaction completed." + + +async def _clear(_args: str, ctx: CommandContext) -> str: + ctx.clear_messages() + return "Conversation cleared." + + +async def _model(args: str, ctx: CommandContext) -> str: + value = args.strip() + if value: + try: + changed = ctx.set_model(value) + except ValueError as exc: + return str(exc) + return f"Model changed to: {changed}" + return f"Current model: {ctx.model_config.model}" + + +async def _agent(_args: str, _ctx: CommandContext) -> str: + return "Use /agent in TUI to open the model picker, or /model to switch directly." + + +async def _cost(_args: str, ctx: CommandContext) -> str: + tracker = ctx.cost_tracker + config = ctx.model_config + return "\n".join( + [ + f"Model: {config.model}", + f"Turns: {tracker.turns}", + f"Input: {tracker.total_input_tokens:,}", + f"Output: {tracker.total_output_tokens:,}", + f"Cache R: {tracker.total_cache_read_tokens:,}", + f"Cache W: {tracker.total_cache_creation_tokens:,}", + f"Cost: ${tracker.total_cost_usd(config):.4f}", + ] + ) + + +async def _resume(args: str, ctx: CommandContext) -> str: + session_prefix = args.strip().lower() + sessions = await list_sessions() + if not session_prefix: + if not sessions: + return "No previous sessions found." + lines = ["Recent sessions:", ""] + for session in sessions[:10]: + lines.append( + " " + f"{session.get('id', '')[:8]} " + f"{session.get('cwd', '')} " + f"({session.get('messageCount', 0)} msgs)" + ) + lines.append("\nUse /resume to resume.") + return "\n".join(lines) + + match = next( + (s for s in sessions if str(s.get("id", "")).lower().startswith(session_prefix)), None + ) + if not match: + return f'No session found matching "{session_prefix}".' + + error = await ctx.resume_session(str(match["id"])) + if error: + return error + return f"Resumed session {str(match['id'])[:8]} ({match.get('messageCount', 0)} messages)." + + +async def _new(_args: str, ctx: CommandContext) -> str: + if ctx.new_session is None: + ctx.clear_messages() + return "Started a new empty context." + new_session_id = await ctx.new_session() + return f"Started new session {new_session_id[:8]} with empty context." + + +async def _plan(_args: str, ctx: CommandContext) -> str: + if ctx.permission_mode == "plan": + ctx.set_permission_mode("default") + return "Plan mode disabled." + ctx.set_permission_mode("plan") + return "Plan mode enabled (read-only)." + + +async def _memory(_args: str, ctx: CommandContext) -> str: + content = await load_agent_memory(ctx.cwd) + if not content: + return "No memory file found." + if len(content) > 3000: + return content[:3000] + f"\n\n... ({len(content)} chars total)" + return content + + +async def _config(_args: str, ctx: CommandContext) -> str: + return "\n".join( + [ + f"Model: {ctx.model_config.model}", + f"Context: {ctx.model_config.context_window:,}", + f"Max output: {ctx.model_config.max_output_tokens:,}", + f"Mode: {ctx.permission_mode}", + f"CWD: {ctx.cwd}", + f"Session: {ctx.session_id[:8]}", + f"Tools: {len(ctx.tools)} loaded", + ] + ) + + +async def _status(_args: str, ctx: CommandContext) -> str: + token_count = estimate_message_tokens(ctx.messages) + threshold = ctx.model_config.context_window - ctx.model_config.max_output_tokens - 13_000 + pct = int((token_count / threshold) * 100) if threshold > 0 else 0 + return "\n".join( + [ + f"Session: {ctx.session_id[:8]}", + f"Model: {ctx.model_config.model}", + f"Messages: {len(ctx.messages)}", + f"Tokens: {token_count:,} / {threshold:,} ({pct}%)", + f"Mode: {ctx.permission_mode}", + f"CWD: {ctx.cwd}", + f"Files mod: {len(ctx.file_history.tracked_files)}", + f"Snapshots: {len(ctx.file_history.snapshots)}", + ] + ) + + +async def _skills(_args: str, ctx: CommandContext) -> str: + skills = get_loaded_skills() + if not skills: + await initialize_skills(ctx.cwd) + skills = get_loaded_skills() + if not skills: + return "No skills loaded. Place skills in .agents/skills, ~/.agents/skills, or ~/.env/skills." + lines = ["Available skills:", ""] + for skill in skills: + lines.append(f" {skill.name} {skill.description}") + return "\n".join(lines) + + +async def _context(_args: str, ctx: CommandContext) -> str: + token_count = estimate_message_tokens(ctx.messages) + context_window = ctx.model_config.context_window + max_output = ctx.model_config.max_output_tokens + usable = context_window - max_output - 13_000 + pct = min(100, int((token_count / usable) * 100)) if usable > 0 else 0 + return "\n".join( + [ + f"Context Usage {ctx.model_config.model}", + f"Usage: {pct}%", + f"Total: {token_count:,} / {usable:,} (window: {context_window:,})", + ] + ) + + +async def _exit(_args: str, _ctx: CommandContext) -> str: + raise SystemExit(0) + + +async def _reload(_args: str, ctx: CommandContext) -> str: + if not ctx.dev_mode: + return "Reload is only available in dev mode. Restart in dev mode with `--dev`." + raise ReloadRequested() + + +_COMMANDS: list[_Command] = [ + _Command("help", "Show available commands", _help, examples=["/help"], aliases=["h", "?"]), + _Command("compact", "Compact conversation context", _compact), + _Command("clear", "Clear conversation history", _clear, aliases=["reset"]), + _Command( + "model", + "Show or change model", + _model, + argument_hint="", + examples=["/model kimi", "/model minimax"], + aliases=["m"], + ), + _Command("agent", "Open model picker (TUI)", _agent, examples=["/agent"]), + _Command("cost", "Show token usage and cost", _cost), + _Command( + "resume", + "Resume a previous session", + _resume, + argument_hint="", + examples=["/resume ab12cd34"], + aliases=["r"], + ), + _Command("new", "Create a new empty context/session", _new, examples=["/new"], aliases=["n"]), + _Command("plan", "Toggle plan mode (read-only)", _plan), + _Command("memory", "Show memory files", _memory), + _Command("config", "Show current configuration", _config), + _Command("status", "Show session status", _status), + _Command("skills", "List available skills", _skills), + _Command("context", "Show context window usage", _context, aliases=["ctx"]), + _Command( + "reload", + "Reload RTE-AI process (dev mode)", + _reload, + examples=["/reload"], + ), + _Command("exit", "Exit RTE-AI", _exit, aliases=["quit", "q"]), +] + + +def get_commands(cwd: str | None = None) -> list[SlashCommand]: + commands: list[_Command] = list(_COMMANDS) + if cwd: + commands.extend(_load_custom_commands(cwd)) + return list(commands) + + +def get_command_info_list(cwd: str | None = None) -> list[dict[str, Any]]: + return [ + { + "name": command.name, + "description": command.description, + "argument_hint": command.argument_hint, + "examples": list(command.examples), + "aliases": list(command.aliases), + } + for command in get_commands(cwd) + ] + + +async def execute_command(command_line: str, context: CommandContext) -> str | None: + if not command_line.startswith("/"): + return None + raw = command_line[1:].strip() + if not raw: + return None + + parts = raw.split(maxsplit=1) + command_name = parts[0] + args = parts[1] if len(parts) > 1 else "" + + command = next( + ( + c + for c in get_commands(context.cwd) + if c.name == command_name or command_name in c.aliases + ), + None, + ) + if command is None: + return f"Unknown command: /{command_name}. Use /help." + return await command.execute(args, context) diff --git a/eagent/context/__init__.py b/eagent/context/__init__.py new file mode 100644 index 0000000..8ec9b74 --- /dev/null +++ b/eagent/context/__init__.py @@ -0,0 +1,64 @@ +"""Context and session helpers.""" + +from eagent.context.agent_config import ( + AgentProfile, + AgentProfileSet, + get_agent_config_path, + load_agent_profiles, + save_agent_profiles, + set_active_profile, +) +from eagent.context.compaction import CompactParams, CompactResult, compact, should_auto_compact +from eagent.context.git_context import get_git_context, get_git_status_short +from eagent.context.memory import has_agent_memory, load_agent_memory +from eagent.context.post_compact import create_post_compact_attachments +from eagent.context.session_store import ( + SessionSummary, + init_session, + list_session_summaries, + list_session_summaries_sync, + list_sessions, + list_sessions_sync, + load_session, + save_message, +) +from eagent.context.token_counting import ( + estimate_json_tokens, + estimate_message_tokens, + estimate_single_message_tokens, + estimate_system_prompt_tokens, + estimate_tokens, + truncate_to_token_budget, +) + +__all__ = [ + "AgentProfile", + "AgentProfileSet", + "get_agent_config_path", + "load_agent_profiles", + "save_agent_profiles", + "set_active_profile", + "CompactParams", + "CompactResult", + "compact", + "should_auto_compact", + "get_git_context", + "get_git_status_short", + "has_agent_memory", + "load_agent_memory", + "create_post_compact_attachments", + "SessionSummary", + "init_session", + "list_session_summaries", + "list_session_summaries_sync", + "list_sessions", + "list_sessions_sync", + "load_session", + "save_message", + "estimate_tokens", + "estimate_json_tokens", + "estimate_single_message_tokens", + "estimate_message_tokens", + "estimate_system_prompt_tokens", + "truncate_to_token_budget", +] diff --git a/eagent/context/agent_config.py b/eagent/context/agent_config.py new file mode 100644 index 0000000..479f644 --- /dev/null +++ b/eagent/context/agent_config.py @@ -0,0 +1,126 @@ +"""Agent profile configuration from ~/.env/agent.json.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from eagent.paths import env_root + +AGENT_CONFIG_NAME = "agent.json" + + +@dataclass +class AgentProfile: + name: str + provider: str + model: str + key: str + base_url: str + + +@dataclass +class AgentProfileSet: + active_name: str | None + profiles: list[AgentProfile] + load_error: str | None = None + + @property + def active(self) -> AgentProfile | None: + if not self.profiles: + return None + if self.active_name: + for profile in self.profiles: + if profile.name == self.active_name: + return profile + return self.profiles[0] + + +def get_agent_config_path() -> Path: + return env_root() / AGENT_CONFIG_NAME + + +def _ensure_file() -> Path: + path = get_agent_config_path() + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + path.write_text('{"active": "", "profiles": []}\n', encoding="utf-8") + return path + + +def _parse_profile(raw: Any) -> AgentProfile | None: + if not isinstance(raw, dict): + return None + name = str(raw.get("name", "")).strip() + provider = str(raw.get("provider", "")).strip() + model = str(raw.get("model", "")).strip() + key = str(raw.get("key", "")).strip() + base_url = str(raw.get("base_url", "")).strip() + if not all([name, provider, model, key, base_url]): + return None + return AgentProfile(name=name, provider=provider, model=model, key=key, base_url=base_url) + + +def load_agent_profiles() -> AgentProfileSet: + path = _ensure_file() + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except Exception as exc: + return AgentProfileSet( + active_name=None, profiles=[], load_error=f"Invalid agent.json: {exc}" + ) + + if not isinstance(payload, dict): + return AgentProfileSet( + active_name=None, + profiles=[], + load_error="Invalid agent.json: top-level object is required.", + ) + + raw_profiles = payload.get("profiles") + profiles = [] + if isinstance(raw_profiles, list): + profiles = [profile for profile in (_parse_profile(p) for p in raw_profiles) if profile] + + active_name = str(payload.get("active", "")).strip() or None + if profiles and ( + active_name is None or all(profile.name != active_name for profile in profiles) + ): + active_name = profiles[0].name + save_agent_profiles(AgentProfileSet(active_name=active_name, profiles=profiles)) + + return AgentProfileSet(active_name=active_name, profiles=profiles) + + +def save_agent_profiles(profile_set: AgentProfileSet) -> None: + path = _ensure_file() + active_name = profile_set.active_name or (profile_set.active.name if profile_set.active else "") + payload = { + "active": active_name, + "profiles": [ + { + "name": p.name, + "provider": p.provider, + "model": p.model, + "key": p.key, + "base_url": p.base_url, + } + for p in profile_set.profiles + ], + } + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + + +def set_active_profile(name: str, profile_set: AgentProfileSet) -> AgentProfileSet: + match = next((profile for profile in profile_set.profiles if profile.name == name), None) + if match is None: + raise KeyError(f'Profile "{name}" not found.') + updated = AgentProfileSet( + active_name=match.name, + profiles=profile_set.profiles, + load_error=profile_set.load_error, + ) + save_agent_profiles(updated) + return updated diff --git a/eagent/context/compaction.py b/eagent/context/compaction.py new file mode 100644 index 0000000..bb236b6 --- /dev/null +++ b/eagent/context/compaction.py @@ -0,0 +1,73 @@ +"""Conversation compaction logic.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from eagent.context.token_counting import estimate_message_tokens +from eagent.core.types import Message, SystemPromptBlock, TextBlock +from eagent.prompt.compact_prompt import COMPACT_PROMPT, format_compact_summary + +PRESERVE_RECENT_TURNS = 3 +SAFETY_MARGIN = 13_000 + + +@dataclass +class CompactParams: + api_key: str + model: str + system_prompt_blocks: list[SystemPromptBlock] | None = None + custom_prompt: str | None = None + + +@dataclass +class CompactResult: + compacted: list[Message] + old_tokens: int + new_tokens: int + + +async def compact( + messages: list[Message], + params: CompactParams, + call_model: Callable[[str, str, str, str], Awaitable[str]], +) -> CompactResult: + old_tokens = estimate_message_tokens(messages) + if len(messages) <= PRESERVE_RECENT_TURNS * 2: + return CompactResult(compacted=messages, old_tokens=old_tokens, new_tokens=old_tokens) + + keep = messages[-(PRESERVE_RECENT_TURNS * 2) :] + summarize = messages[: -(PRESERVE_RECENT_TURNS * 2)] + + serialized: list[str] = [] + for message in summarize: + serialized.append(message.role.upper()) + for block in message.content: + if getattr(block, "type", None) == "text": + serialized.append(getattr(block, "text", "")) + elif getattr(block, "type", None) == "tool_use": + serialized.append( + f"[TOOL_USE] {getattr(block, 'name', '')} {getattr(block, 'input', {})}" + ) + elif getattr(block, "type", None) == "tool_result": + serialized.append(f"[TOOL_RESULT] {getattr(block, 'content', '')}") + + prompt = (params.custom_prompt or COMPACT_PROMPT).strip() + "\n\n" + "\n".join(serialized) + system_prompt = "You are a summarizer. Preserve file paths, commands, and decisions." + summary = await call_model(system_prompt, prompt, params.model, params.api_key) + + summary_msg = Message( + role="user", + content=[TextBlock(type="text", text=format_compact_summary(summary))], + ) + compacted = [summary_msg, *keep] + new_tokens = estimate_message_tokens(compacted) + return CompactResult(compacted=compacted, old_tokens=old_tokens, new_tokens=new_tokens) + + +def should_auto_compact( + messages: list[Message], context_window: int, max_output_tokens: int +) -> bool: + threshold = context_window - max_output_tokens - SAFETY_MARGIN + return threshold > 0 and estimate_message_tokens(messages) >= threshold diff --git a/eagent/context/git_context.py b/eagent/context/git_context.py new file mode 100644 index 0000000..758a19a --- /dev/null +++ b/eagent/context/git_context.py @@ -0,0 +1,65 @@ +"""Git context helper for prompts.""" + +from __future__ import annotations + +import asyncio + +_CACHE: dict[str, tuple[float, str]] = {} +_TTL = 300.0 + + +async def _git(cwd: str, *args: str) -> str | None: + try: + proc = await asyncio.create_subprocess_exec( + "git", + *args, + cwd=cwd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _stderr = await asyncio.wait_for(proc.communicate(), timeout=5.0) + if proc.returncode != 0: + return None + return stdout.decode("utf-8", errors="replace").strip() + except Exception: + return None + + +async def get_git_context(cwd: str) -> str: + import time + + now = time.time() + cached = _CACHE.get(cwd) + if cached and now - cached[0] < _TTL: + return cached[1] + + inside = await _git(cwd, "rev-parse", "--is-inside-work-tree") + if inside != "true": + result = "Not a git repository." + _CACHE[cwd] = (now, result) + return result + + branch = await _git(cwd, "rev-parse", "--abbrev-ref", "HEAD") + status = await _git(cwd, "status", "--porcelain") + log = await _git(cwd, "log", "--oneline", "-10") + + lines: list[str] = [] + if branch: + lines.append(f"Current branch: {branch}") + if status is not None: + lines.append("Status: Clean working tree" if not status else f"Status:\n{status}") + if log: + lines.append(f"Recent commits:\n{log}") + + result = "\n\n".join(lines) + _CACHE[cwd] = (now, result) + return result + + +async def get_git_status_short(cwd: str) -> str: + branch = await _git(cwd, "rev-parse", "--abbrev-ref", "HEAD") + if not branch: + return "" + status = await _git(cwd, "status", "--porcelain") + count = len([line for line in (status or "").splitlines() if line.strip()]) + return f"{branch} (clean)" if count == 0 else f"{branch} ({count} changed)" diff --git a/eagent/context/memory.py b/eagent/context/memory.py new file mode 100644 index 0000000..b40269e --- /dev/null +++ b/eagent/context/memory.py @@ -0,0 +1,55 @@ +"""Load project/user AI context markdown files.""" + +from __future__ import annotations + +from pathlib import Path + +from eagent.paths import env_root + +CONTEXT_FILES = [ + "tasks.md", + "requirements.md", + "requirements", + "design.md", + "ENV_AGENT.md", + "ENV_AGENT.local.md", +] +LEGACY_PROJECT_FILES = [ + "ENV_AGENT.md", + "ENV_AGENT.local.md", +] + + +def _try_read(path: Path) -> str | None: + try: + text = path.read_text(encoding="utf-8") + return text.strip() if text.strip() else None + except Exception: + return None + + +async def load_agent_memory(cwd: str) -> str: + base = Path(cwd).resolve() + fragments: list[str] = [] + seen_sources: set[Path] = set() + + source_paths: list[Path] = [] + source_paths.extend(base / ".agents" / rel for rel in CONTEXT_FILES) + source_paths.extend(base / rel for rel in LEGACY_PROJECT_FILES) + source_paths.extend(env_root() / rel for rel in CONTEXT_FILES) + + for p in source_paths: + resolved = p.resolve() + if resolved in seen_sources: + continue + content = _try_read(p) + if content: + seen_sources.add(resolved) + source = str(p.relative_to(base)) if p.is_relative_to(base) else str(p) + fragments.append(f"# Source: {source}\n\n{content}") + + return "\n\n---\n\n".join(fragments) + + +async def has_agent_memory(cwd: str) -> bool: + return bool(await load_agent_memory(cwd)) diff --git a/eagent/context/post_compact.py b/eagent/context/post_compact.py new file mode 100644 index 0000000..ec02a93 --- /dev/null +++ b/eagent/context/post_compact.py @@ -0,0 +1,15 @@ +"""Post-compact file attachment refresh.""" + +from __future__ import annotations + +from eagent.core.types import Message + + +async def create_post_compact_attachments( + preserved_messages: list[Message], read_file_state, context_budget: int = 50_000 +) -> list[Message]: + _ = preserved_messages + _ = read_file_state + _ = context_budget + # Simplified placeholder: keep behavior optional. + return [] diff --git a/eagent/context/session_store.py b/eagent/context/session_store.py new file mode 100644 index 0000000..ceb6eb5 --- /dev/null +++ b/eagent/context/session_store.py @@ -0,0 +1,248 @@ +"""Session persistence in JSONL format.""" + +from __future__ import annotations + +import json +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from eagent.paths import env_root +from eagent.core.types import ( + ImageBlock, + ImageSource, + Message, + RedactedThinkingBlock, + TextBlock, + ThinkingBlock, + ToolResultBlock, + ToolUseBlock, +) + +SESSIONS_BASE = env_root() / "sessions" +TRANSCRIPT_FILE = "transcript.jsonl" +META_FILE = "meta.json" + + +@dataclass(frozen=True) +class SessionSummary: + id: str + cwd: str + message_count: int + updated_at: int + + @property + def prefix(self) -> str: + return self.id[:8] + + +def get_session_dir(session_id: str) -> Path: + return SESSIONS_BASE / session_id + + +def get_session_path(session_id: str) -> Path: + return get_session_dir(session_id) / TRANSCRIPT_FILE + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def _serialize_message(message: Message) -> dict[str, Any]: + blocks: list[dict[str, Any]] = [] + for block in message.content: + payload = dict(vars(block)) + if isinstance(block, ImageBlock): + payload["source"] = dict(vars(block.source)) + blocks.append(payload) + return {"role": message.role, "content": blocks, "id": message.id} + + +def _deserialize_message(raw: dict[str, Any]) -> Message: + content = [] + for block in raw.get("content", []): + if not isinstance(block, dict): + continue + btype = block.get("type") + if btype == "text": + content.append(TextBlock(type="text", text=str(block.get("text", "")))) + elif btype == "tool_use": + content.append( + ToolUseBlock( + type="tool_use", + id=str(block.get("id", "")), + name=str(block.get("name", "")), + input=block.get("input", {}) if isinstance(block.get("input"), dict) else {}, + ) + ) + elif btype == "tool_result": + content.append( + ToolResultBlock( + type="tool_result", + tool_use_id=str(block.get("tool_use_id", "")), + content=block.get("content", ""), + is_error=bool(block.get("is_error", False)), + ) + ) + elif btype == "thinking": + content.append(ThinkingBlock(type="thinking", thinking=str(block.get("thinking", "")))) + elif btype == "redacted_thinking": + content.append( + RedactedThinkingBlock(type="redacted_thinking", data=str(block.get("data", ""))) + ) + elif btype == "image": + source = block.get("source", {}) + media_type = ( + source.get("media_type", "image/png") if isinstance(source, dict) else "image/png" + ) + data = source.get("data", "") if isinstance(source, dict) else "" + content.append( + ImageBlock( + type="image", + source=ImageSource(type="base64", media_type=str(media_type), data=str(data)), + ) + ) + role = raw.get("role", "user") + return Message( + role=role if role in {"user", "assistant"} else "user", content=content, id=raw.get("id") + ) + + +def _write_meta( + session_id: str, cwd: str, message_count: int, created_at: int | None = None +) -> None: + folder = get_session_dir(session_id) + folder.mkdir(parents=True, exist_ok=True) + now = _now_ms() + meta = { + "id": session_id, + "createdAt": created_at if created_at is not None else now, + "updatedAt": now, + "cwd": cwd, + "messageCount": message_count, + } + (folder / META_FILE).write_text( + json.dumps(meta, indent=2, ensure_ascii=False) + "\n", encoding="utf-8" + ) + + +async def init_session(session_id: str, cwd: str) -> None: + _write_meta(session_id, cwd=cwd, message_count=0) + + +async def save_message(session_id: str, message: Message, cwd: str | None = None) -> None: + folder = get_session_dir(session_id) + folder.mkdir(parents=True, exist_ok=True) + + entry = { + "type": message.role, + "message": _serialize_message(message), + "timestamp": _now_ms(), + "id": str(uuid.uuid4()), + } + with get_session_path(session_id).open("a", encoding="utf-8") as f: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + existing = None + meta_path = folder / META_FILE + if meta_path.exists(): + try: + existing = json.loads(meta_path.read_text(encoding="utf-8")) + except Exception: + existing = None + + message_count = 1 + if existing and isinstance(existing.get("messageCount"), int): + message_count = int(existing["messageCount"]) + 1 + + _write_meta( + session_id, + cwd=cwd or (existing.get("cwd") if isinstance(existing, dict) else ""), + message_count=message_count, + created_at=existing.get("createdAt") if isinstance(existing, dict) else None, + ) + + +async def load_session(session_id: str) -> list[Message]: + path = get_session_path(session_id) + if not path.exists(): + return [] + + messages: list[Message] = [] + for line in path.read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + try: + entry = json.loads(line) + raw_message = entry.get("message") + if isinstance(raw_message, dict): + messages.append(_deserialize_message(raw_message)) + except Exception: + continue + return messages + + +def list_sessions_sync() -> list[dict[str, Any]]: + SESSIONS_BASE.mkdir(parents=True, exist_ok=True) + sessions: list[dict[str, Any]] = [] + + for folder in SESSIONS_BASE.iterdir(): + if not folder.is_dir(): + continue + meta_path = folder / META_FILE + if meta_path.exists(): + try: + sessions.append(json.loads(meta_path.read_text(encoding="utf-8"))) + continue + except Exception: + pass + + transcript = folder / TRANSCRIPT_FILE + mtime = int(transcript.stat().st_mtime * 1000) if transcript.exists() else 0 + count = 0 + if transcript.exists(): + count = sum(1 for _ in transcript.open("r", encoding="utf-8")) + + sessions.append( + { + "id": folder.name, + "createdAt": mtime, + "updatedAt": mtime, + "cwd": "", + "messageCount": count, + } + ) + + sessions.sort(key=lambda x: x.get("updatedAt", 0), reverse=True) + return sessions + + +async def list_sessions() -> list[dict[str, Any]]: + return list_sessions_sync() + + +def list_session_summaries_sync(limit: int | None = None) -> list[SessionSummary]: + summaries: list[SessionSummary] = [] + for session in list_sessions_sync(): + session_id = str(session.get("id") or "") + if not session_id: + continue + count = session.get("messageCount") + updated = session.get("updatedAt") + summaries.append( + SessionSummary( + id=session_id, + cwd=str(session.get("cwd") or ""), + message_count=int(count) if isinstance(count, int) else 0, + updated_at=int(updated) if isinstance(updated, int) else 0, + ) + ) + if limit is not None and len(summaries) >= limit: + break + return summaries + + +async def list_session_summaries(limit: int | None = None) -> list[SessionSummary]: + return list_session_summaries_sync(limit) diff --git a/eagent/context/token_counting.py b/eagent/context/token_counting.py new file mode 100644 index 0000000..01e4518 --- /dev/null +++ b/eagent/context/token_counting.py @@ -0,0 +1,66 @@ +"""Fast token estimation utilities.""" + +from __future__ import annotations + +from eagent.core.types import Message, SystemPromptBlock + +CHARS_PER_TOKEN = 4 +JSON_CHARS_PER_TOKEN = 2 + + +def estimate_tokens(text: str) -> int: + if not text: + return 0 + return round(len(text) / CHARS_PER_TOKEN) + + +def estimate_json_tokens(obj: object) -> int: + import json + + try: + return round(len(json.dumps(obj, ensure_ascii=False)) / JSON_CHARS_PER_TOKEN) + except Exception: + return 0 + + +def estimate_single_message_tokens(message: Message) -> int: + total = 4 + for block in message.content: + btype = getattr(block, "type", None) + if btype == "text": + total += estimate_tokens(getattr(block, "text", "")) + elif btype == "tool_use": + total += estimate_tokens(getattr(block, "name", "")) + total += estimate_json_tokens(getattr(block, "input", {})) + elif btype == "tool_result": + content = getattr(block, "content", "") + if isinstance(content, str): + total += estimate_tokens(content) + else: + total += estimate_json_tokens(content) + elif btype == "thinking": + total += estimate_tokens(getattr(block, "thinking", "")) + elif btype == "redacted_thinking": + total += estimate_tokens(getattr(block, "data", "")) + elif btype == "image": + total += 1500 + return total + + +def estimate_message_tokens(messages: list[Message]) -> int: + return sum(estimate_single_message_tokens(m) for m in messages) + + +def estimate_system_prompt_tokens(blocks: list[SystemPromptBlock]) -> int: + return sum(estimate_tokens(b.text) for b in blocks) + + +def truncate_to_token_budget(text: str, max_tokens: int) -> str: + if estimate_tokens(text) <= max_tokens: + return text + limit = max_tokens * CHARS_PER_TOKEN + clipped = text[:limit] + idx = max(clipped.rfind("\n"), clipped.rfind(" ")) + if idx > 0: + return f"{clipped[:idx]}\n...[truncated]" + return f"{clipped}...[truncated]" diff --git a/eagent/core/__init__.py b/eagent/core/__init__.py new file mode 100644 index 0000000..e4d58b5 --- /dev/null +++ b/eagent/core/__init__.py @@ -0,0 +1,6 @@ +"""Core package marker. + +Avoid importing heavy modules here to prevent circular imports. +""" + +__all__: list[str] = [] diff --git a/eagent/core/agent_loop.py b/eagent/core/agent_loop.py new file mode 100644 index 0000000..e904d79 --- /dev/null +++ b/eagent/core/agent_loop.py @@ -0,0 +1,415 @@ +"""Core agent orchestration loop.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncGenerator +from typing import Any + +from eagent.context.compaction import CompactParams +from eagent.context.compaction import compact as compact_messages +from eagent.context.token_counting import estimate_message_tokens +from eagent.core.api_client import call_model, get_model_config +from eagent.core.errors import AbortError, PromptTooLongError, classify_error +from eagent.core.streaming_executor import execute_tools +from eagent.core.types import ( + ContentBlock, + Message, + QueryParams, + StreamEvent, + SystemPromptBlock, + TextBlock, + ToolContext, + ToolResult, + ToolResultBlock, + ToolUseBlock, +) +from eagent.permissions.engine import PermissionContext, check_permission + +AUTOCOMPACT_BUFFER_TOKENS = 13_000 +MAX_PTL_RETRIES = 3 +MICRO_COMPACT_THRESHOLD_CHARS = 50_000 + + +def _should_auto_compact( + messages: list[Message], context_window: int, max_output_tokens: int +) -> bool: + threshold = context_window - max_output_tokens - AUTOCOMPACT_BUFFER_TOKENS + return estimate_message_tokens(messages) > threshold + + +def _micro_compact_messages(messages: list[Message]) -> None: + threshold_index = max(0, len(messages) - 6) + for i in range(threshold_index): + message = messages[i] + if message.role != "user": + continue + for idx, block in enumerate(message.content): + if getattr(block, "type", None) != "tool_result": + continue + content = block.content if isinstance(block.content, str) else str(block.content) + if len(content) > MICRO_COMPACT_THRESHOLD_CHARS: + message.content[idx] = ToolResultBlock( + type="tool_result", + tool_use_id=block.tool_use_id, + content=( + content[:5000] + + f"\n\n[Content truncated: was {len(content)} chars. Re-read if needed.]" + ), + is_error=block.is_error, + ) + + +def _truncate_for_ptl(messages: list[Message]) -> list[Message]: + if len(messages) <= 2: + return messages + truncated = list(messages) + del truncated[:2] + return truncated + + +def _assistant_message(content: list[ContentBlock]) -> Message: + return Message(role="assistant", content=content, id=str(uuid.uuid4())) + + +def _tool_result_message(results: list[tuple[str, ToolResult]]) -> Message: + return Message( + role="user", + content=[ + ToolResultBlock( + type="tool_result", + tool_use_id=tool_use_id, + content=result.result, + is_error=result.is_error, + ) + for tool_use_id, result in results + ], + id=str(uuid.uuid4()), + ) + + +def _hook_prompt_message(prompt: str) -> Message: + return Message( + role="user", + content=[TextBlock(type="text", text=prompt)], + id=str(uuid.uuid4()), + ) + + +def _describe_tool_use(tool_use: ToolUseBlock) -> str: + input_data = tool_use.input + if tool_use.name == "Bash": + command = input_data.get("command") or "(no command)" + return f"Bash: {command}" + file_path = input_data.get("file_path") or input_data.get("path") or input_data.get("filePath") + if file_path: + return f"{tool_use.name}: {file_path}" + return f"{tool_use.name}: {str(input_data)[:200]}" + + +async def _perform_compaction( + messages: list[Message], params: QueryParams +) -> tuple[list[Message], int, int]: + async def _call_model_for_compact( + system_prompt: str, prompt: str, model: str, api_key: str + ) -> str: + effective_api_key = params.api_key_override or api_key + synthetic_message = Message(role="user", content=[TextBlock(type="text", text=prompt)]) + model_config = get_model_config(model) + text_parts: list[str] = [] + async for event in call_model( + messages=[synthetic_message], + tools=[], + model_config=model_config, + system_prompt_blocks=[SystemPromptBlock(type="text", text=system_prompt)], + api_key=effective_api_key, + api_base_url=params.api_base_url, + ): + if event["type"] == "assistant_text": + text_parts.append(event["text"]) + return "".join(text_parts) + + result = await compact_messages( + messages, + CompactParams( + api_key=params.api_key, + model=params.model_config.model, + system_prompt_blocks=params.system_prompt_blocks, + ), + _call_model_for_compact, + ) + return result.compacted, result.old_tokens, result.new_tokens + + +async def _check_tool_permission( + tool_use: ToolUseBlock, params: QueryParams, context: ToolContext +) -> tuple[bool, str | None]: + tool = next((t for t in params.tools if t.name == tool_use.name), None) + if tool is None: + return True, None + + # Evaluate rule engine first. + decision = await check_permission( + tool_use.name, + tool_use.input, + PermissionContext( + cwd=params.cwd, permission_mode=params.permission_mode, tools=params.tools + ), + ) + + if decision.behavior == "deny": + return False, decision.message or f"Permission denied for {tool_use.name}." + if decision.behavior == "allow": + return True, None + + # Fallback to runtime permission callback. + user_decision = await params.on_permission_request( + tool_use.name, + tool_use.input, + _describe_tool_use(tool_use), + ) + return user_decision.behavior == "allow", user_decision.message + + +async def agent_loop(params: QueryParams) -> AsyncGenerator[StreamEvent, None]: + messages = params.messages + turn_count = 0 + ptl_retries = 0 + hook_prompt_appends: list[str] = [] + + tool_context = ToolContext( + cwd=params.cwd, + read_file_state=params.read_file_state, + file_history=params.file_history, + modified_files=set(), + session_id=params.session_id, + abort_signal=params.abort_signal, + permission_mode=params.permission_mode, + on_permission_request=params.on_permission_request, + hook_runtime=params.hook_runtime, + on_hook_prompt_append=hook_prompt_appends.append, + dev_mode=params.dev_mode, + ) + + async def _run_on_error_hooks( + *, + target: str, + source: str, + error_value: Any, + extra: dict[str, Any] | None = None, + ) -> AsyncGenerator[StreamEvent, None]: + if params.hook_runtime is None: + return + + variables: dict[str, Any] = { + "source": source, + "target": target, + "error": str(error_value), + "session_id": params.session_id, + } + if extra: + variables.update(extra) + + outcome = await params.hook_runtime.run( + "on_error", + target=target, + variables=variables, + cwd=params.cwd, + dev_mode=params.dev_mode, + ) + if params.dev_mode: + for line in outcome.debug_lines: + yield {"type": "hook_debug", "text": line} + if outcome.prompt_appends: + for prompt in outcome.prompt_appends: + messages.append(_hook_prompt_message(prompt)) + + while True: + if getattr(params.abort_signal, "aborted", False): + yield {"type": "error", "error": AbortError("Aborted")} + return + + if _should_auto_compact( + messages, params.model_config.context_window, params.model_config.max_output_tokens + ): + try: + compacted, old_tokens, new_tokens = await _perform_compaction(messages, params) + messages[:] = compacted + yield {"type": "compact", "old_tokens": old_tokens, "new_tokens": new_tokens} + except Exception as exc: + async for hook_event in _run_on_error_hooks( + target="compact", + source="compact", + error_value=exc, + ): + yield hook_event + yield {"type": "error", "error": Exception(f"Auto-compact failed: {exc}")} + + _micro_compact_messages(messages) + + tool_uses: list[ToolUseBlock] = [] + assistant_content: list[ContentBlock] = [] + stop_reason = "end_turn" + + try: + async for event in call_model( + messages=messages, + tools=params.tools, + model_config=params.model_config, + system_prompt_blocks=params.system_prompt_blocks, + api_key=params.api_key_override or params.api_key, + api_base_url=params.api_base_url, + enable_thinking=params.enable_thinking, + thinking_budget=params.thinking_budget, + abort_signal=params.abort_signal, + ): + yield event + if event["type"] == "tool_use": + tool_uses.append(event["tool_use"]) + elif event["type"] == "assistant_message": + assistant_content = event["message"].content + elif event["type"] == "turn_complete": + stop_reason = event["stop_reason"] + except PromptTooLongError: + async for hook_event in _run_on_error_hooks( + target="model", + source="model", + error_value="prompt_too_long", + ): + yield hook_event + ptl_retries += 1 + if ptl_retries >= MAX_PTL_RETRIES: + yield { + "type": "error", + "error": Exception( + "Prompt too long after " + f"{MAX_PTL_RETRIES} recovery attempts. " + "Use /compact and retry." + ), + } + return + messages[:] = _truncate_for_ptl(messages) + yield { + "type": "error", + "error": Exception( + "Prompt too long, truncating old turns " + f"(attempt {ptl_retries}/{MAX_PTL_RETRIES})." + ), + } + continue + except Exception as exc: + classified = classify_error(exc) + async for hook_event in _run_on_error_hooks( + target="model", + source="model", + error_value=classified, + ): + yield hook_event + yield {"type": "error", "error": classified} + return + + ptl_retries = 0 + if assistant_content: + messages.append(_assistant_message(assistant_content)) + + if not tool_uses: + yield {"type": "turn_complete", "stop_reason": stop_reason} + return + + tool_results: list[tuple[str, ToolResult]] = [] + allowed_tools: list[ToolUseBlock] = [] + + for tool_use in tool_uses: + allowed, message = await _check_tool_permission(tool_use, params, tool_context) + if allowed: + allowed_tools.append(tool_use) + continue + + denied = ToolResult( + result=message or f"Permission denied for {tool_use.name}.", is_error=True + ) + tool_results.append((tool_use.id, denied)) + yield { + "type": "tool_result", + "tool_use_id": tool_use.id, + "tool_name": tool_use.name, + "result": denied.result, + "is_error": True, + } + + if allowed_tools: + async for event in execute_tools(allowed_tools, params.tools, tool_context): + yield event + if event["type"] == "tool_result": + tool_results.append( + ( + event["tool_use_id"], + ToolResult(result=event["result"], is_error=event["is_error"]), + ) + ) + + ordered: list[tuple[str, ToolResult]] = [] + for tool_use in tool_uses: + found = next((r for r in tool_results if r[0] == tool_use.id), None) + ordered.append( + found or (tool_use.id, ToolResult(result="Tool execution failed", is_error=True)) + ) + + messages.append(_tool_result_message(ordered)) + if hook_prompt_appends: + for prompt in hook_prompt_appends: + messages.append(_hook_prompt_message(prompt)) + hook_prompt_appends.clear() + + turn_count += 1 + if turn_count >= params.max_turns: + yield {"type": "max_turns_reached", "max_turns": params.max_turns} + return + + +async def run_sub_agent( + prompt: str, + params: QueryParams, + tools: list[str] | None = None, + disallowed_tools: list[str] | None = None, + max_turns: int | None = None, + model: str | None = None, +) -> str: + filtered_tools = params.tools + if tools is not None: + allow = set(tools) + filtered_tools = [tool for tool in filtered_tools if tool.name in allow] + if disallowed_tools is not None: + deny = set(disallowed_tools) + filtered_tools = [tool for tool in filtered_tools if tool.name not in deny] + + sub_messages = [ + Message(role="user", content=[TextBlock(type="text", text=prompt)], id=str(uuid.uuid4())) + ] + + sub_params = QueryParams( + messages=sub_messages, + tools=filtered_tools, + model_config=get_model_config(model) if model else params.model_config, + system_prompt_blocks=params.system_prompt_blocks, + permission_mode=params.permission_mode, + api_key=params.api_key, + cwd=params.cwd, + session_id=params.session_id, + on_permission_request=params.on_permission_request, + read_file_state=params.read_file_state.clone(), + file_history=params.file_history, + max_turns=max_turns or params.max_turns, + abort_signal=params.abort_signal, + enable_thinking=params.enable_thinking, + thinking_budget=params.thinking_budget, + api_key_override=params.api_key_override, + api_base_url=params.api_base_url, + ) + + chunks: list[str] = [] + async for event in agent_loop(sub_params): + if event["type"] == "assistant_text": + chunks.append(event["text"]) + params.read_file_state.merge(sub_params.read_file_state) + return "".join(chunks).strip() or "(No response from sub-agent)" diff --git a/eagent/core/api_client.py b/eagent/core/api_client.py new file mode 100644 index 0000000..b0e9280 --- /dev/null +++ b/eagent/core/api_client.py @@ -0,0 +1,336 @@ +"""Model API adapter and model configuration.""" + +from __future__ import annotations + +import asyncio +import json +import os +import re +from collections.abc import AsyncGenerator +from typing import Any + +from eagent.core.errors import PromptTooLongError, classify_error, with_retry +from eagent.core.types import ( + Message, + ModelConfig, + StreamEvent, + SystemPromptBlock, + TextBlock, + ThinkingBlock, + TokenUsage, + Tool, + ToolUseBlock, +) + +try: # pragma: no cover - import availability depends on runtime environment + from anthropic import Anthropic +except Exception: # pragma: no cover + Anthropic = None # type: ignore[assignment] + +MODEL_CONFIGS: dict[str, ModelConfig] = { + "claude-sonnet-4-20250514": ModelConfig( + model="claude-sonnet-4-20250514", + context_window=200_000, + max_output_tokens=16_384, + supports_thinking=True, + supports_caching=True, + price_per_input_token=3 / 1_000_000, + price_per_output_token=15 / 1_000_000, + price_per_cache_read=0.3 / 1_000_000, + price_per_cache_write=3.75 / 1_000_000, + ), + "claude-opus-4-20250514": ModelConfig( + model="claude-opus-4-20250514", + context_window=200_000, + max_output_tokens=16_384, + supports_thinking=True, + supports_caching=True, + price_per_input_token=15 / 1_000_000, + price_per_output_token=75 / 1_000_000, + price_per_cache_read=1.5 / 1_000_000, + price_per_cache_write=18.75 / 1_000_000, + ), + "claude-haiku-4-5-20251001": ModelConfig( + model="claude-haiku-4-5-20251001", + context_window=200_000, + max_output_tokens=16_384, + supports_thinking=False, + supports_caching=True, + price_per_input_token=0.8 / 1_000_000, + price_per_output_token=4 / 1_000_000, + price_per_cache_read=0.08 / 1_000_000, + price_per_cache_write=1 / 1_000_000, + ), +} +MODEL_CONFIGS["sonnet"] = MODEL_CONFIGS["claude-sonnet-4-20250514"] +MODEL_CONFIGS["opus"] = MODEL_CONFIGS["claude-opus-4-20250514"] +MODEL_CONFIGS["haiku"] = MODEL_CONFIGS["claude-haiku-4-5-20251001"] + +_last_usage = TokenUsage() + + +def get_last_usage() -> TokenUsage: + return TokenUsage( + input_tokens=_last_usage.input_tokens, + output_tokens=_last_usage.output_tokens, + cache_read_tokens=_last_usage.cache_read_tokens, + cache_creation_tokens=_last_usage.cache_creation_tokens, + ) + + +def set_last_usage(usage: TokenUsage) -> None: + global _last_usage + _last_usage = usage + + +def get_model_config(model: str) -> ModelConfig: + if model in MODEL_CONFIGS: + return MODEL_CONFIGS[model] + for key, cfg in MODEL_CONFIGS.items(): + if key in model or model in key: + return cfg + base = MODEL_CONFIGS["sonnet"] + return ModelConfig( + model=model, + context_window=base.context_window, + max_output_tokens=base.max_output_tokens, + supports_thinking=base.supports_thinking, + supports_caching=base.supports_caching, + price_per_input_token=base.price_per_input_token, + price_per_output_token=base.price_per_output_token, + price_per_cache_read=base.price_per_cache_read, + price_per_cache_write=base.price_per_cache_write, + ) + + +def _message_to_api(message: Message) -> dict[str, Any]: + content: list[dict[str, Any]] = [] + for block in message.content: + if block.type == "text": + content.append({"type": "text", "text": block.text}) + elif block.type == "tool_use": + content.append( + { + "type": "tool_use", + "id": block.id, + "name": block.name, + "input": block.input, + } + ) + elif block.type == "tool_result": + content.append( + { + "type": "tool_result", + "tool_use_id": block.tool_use_id, + "content": block.content, + "is_error": block.is_error, + } + ) + elif block.type == "thinking": + content.append({"type": "thinking", "thinking": block.thinking}) + elif block.type == "redacted_thinking": + content.append({"type": "redacted_thinking", "data": block.data}) + elif block.type == "image": + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": block.source.media_type, + "data": block.source.data, + }, + } + ) + return {"role": message.role, "content": content} + + +def _tool_to_api(tool: Tool) -> dict[str, Any]: + if callable(tool.description): + description = tool.description(None) + else: + description = tool.description + return { + "name": tool.name, + "description": description, + "input_schema": tool.input_schema or {"type": "object", "properties": {}}, + } + + +def _extract_latest_user_text(messages: list[Message]) -> str: + for message in reversed(messages): + if message.role != "user": + continue + text_parts = [b.text for b in message.content if getattr(b, "type", None) == "text"] + if text_parts: + return "\n".join(text_parts) + return "" + + +def _mock_tool_use(prompt: str) -> ToolUseBlock | None: + match = re.search(r"\[\[tool:(?P[A-Za-z0-9_\-]+)\s*(?P\{.*\})?\]\]", prompt) + if not match: + return None + name = match.group("name") + raw_input = match.group("input") or "{}" + try: + input_obj = json.loads(raw_input) + except json.JSONDecodeError: + input_obj = {} + return ToolUseBlock(type="tool_use", id="mock-tool-use-1", name=name, input=input_obj) + + +async def _mock_call( + messages: list[Message], + tools: list[Tool], +) -> AsyncGenerator[StreamEvent, None]: + latest = _extract_latest_user_text(messages) + tool_use = _mock_tool_use(latest) + if tool_use and not any(t.name == tool_use.name for t in tools): + tool_use = None + + base_text = "[mock] RTE-AI received your request." + if tool_use: + base_text += f" Requesting tool {tool_use.name}." + for chunk in [base_text]: + yield {"type": "assistant_text", "text": chunk} + + content = [TextBlock(type="text", text=base_text)] + if tool_use: + content.append(tool_use) + yield {"type": "tool_use", "tool_use": tool_use} + + message = Message(role="assistant", content=content) + usage = TokenUsage( + input_tokens=max(1, len(latest) // 4), output_tokens=max(1, len(base_text) // 4) + ) + set_last_usage(usage) + yield {"type": "assistant_message", "message": message} + yield {"type": "usage", "usage": usage} + yield {"type": "turn_complete", "stop_reason": "tool_use" if tool_use else "end_turn"} + + +async def _anthropic_call( + messages: list[Message], + tools: list[Tool], + model_config: ModelConfig, + system_prompt_blocks: list[SystemPromptBlock], + api_key: str, + api_base_url: str | None, + enable_thinking: bool, + thinking_budget: int | None, +) -> AsyncGenerator[StreamEvent, None]: + if Anthropic is None: + async for event in _mock_call(messages, tools): + yield event + return + + system_text = "\n\n".join(block.text for block in system_prompt_blocks) + api_messages = [_message_to_api(msg) for msg in messages] + api_tools = [_tool_to_api(tool) for tool in tools] + client = Anthropic(api_key=api_key, base_url=api_base_url or os.getenv("ANTHROPIC_BASE_URL")) + + request: dict[str, Any] = { + "model": model_config.model, + "max_tokens": model_config.max_output_tokens, + "system": system_text, + "messages": api_messages, + } + if api_tools: + request["tools"] = api_tools + + if enable_thinking and model_config.supports_thinking: + request["thinking"] = { + "type": "enabled", + "budget_tokens": thinking_budget or 10_000, + } + + async def _invoke(_: int) -> Any: + return await asyncio.to_thread(lambda: client.messages.create(**request)) + + try: + response = await with_retry(_invoke) + except Exception as raw: + error = classify_error(raw) + if isinstance(error, PromptTooLongError): + raise + yield {"type": "error", "error": error} + return + + content_blocks: list[Any] = [] + for block in response.content: + if block.type == "text": + text = block.text + content_blocks.append(TextBlock(type="text", text=text)) + yield {"type": "assistant_text", "text": text} + elif block.type == "tool_use": + tool_use = ToolUseBlock( + type="tool_use", + id=block.id, + name=block.name, + input=block.input or {}, + ) + content_blocks.append(tool_use) + yield {"type": "tool_use", "tool_use": tool_use} + elif block.type == "thinking": + thinking = ThinkingBlock(type="thinking", thinking=getattr(block, "thinking", "")) + content_blocks.append(thinking) + yield {"type": "thinking", "text": thinking.thinking} + + message = Message(role="assistant", content=content_blocks) + usage = TokenUsage( + input_tokens=( + getattr(response.usage, "input_tokens", 0) if getattr(response, "usage", None) else 0 + ), + output_tokens=( + getattr(response.usage, "output_tokens", 0) if getattr(response, "usage", None) else 0 + ), + cache_read_tokens=( + getattr(response.usage, "cache_read_input_tokens", 0) + if getattr(response, "usage", None) + else 0 + ), + cache_creation_tokens=( + getattr(response.usage, "cache_creation_input_tokens", 0) + if getattr(response, "usage", None) + else 0 + ), + ) + set_last_usage(usage) + + yield {"type": "assistant_message", "message": message} + if usage.input_tokens or usage.output_tokens: + yield {"type": "usage", "usage": usage} + yield { + "type": "turn_complete", + "stop_reason": getattr(response, "stop_reason", "end_turn") or "end_turn", + } + + +async def call_model( + messages: list[Message], + tools: list[Tool], + model_config: ModelConfig, + system_prompt_blocks: list[SystemPromptBlock], + api_key: str, + api_base_url: str | None = None, + enable_thinking: bool = False, + thinking_budget: int | None = None, + abort_signal: Any | None = None, +) -> AsyncGenerator[StreamEvent, None]: + del abort_signal + if os.getenv("ENV_AGENT_MOCK", "").lower() in {"1", "true", "yes"} or not api_key: + async for event in _mock_call(messages, tools): + yield event + return + + async for event in _anthropic_call( + messages=messages, + tools=tools, + model_config=model_config, + system_prompt_blocks=system_prompt_blocks, + api_key=api_key, + api_base_url=api_base_url, + enable_thinking=enable_thinking, + thinking_budget=thinking_budget, + ): + yield event diff --git a/eagent/core/errors.py b/eagent/core/errors.py new file mode 100644 index 0000000..eaa1ced --- /dev/null +++ b/eagent/core/errors.py @@ -0,0 +1,117 @@ +"""Error classification and retry utilities.""" + +from __future__ import annotations + +import asyncio +import random +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import TypeVar + + +class EnvAgentError(Exception): + """Base eagent error.""" + + +class PromptTooLongError(EnvAgentError): + pass + + +@dataclass +class RateLimitError(EnvAgentError): + retry_after_ms: int = 5000 + + +class OverloadedError(EnvAgentError): + pass + + +class AuthenticationError(EnvAgentError): + pass + + +class NetworkError(EnvAgentError): + pass + + +class AbortError(EnvAgentError): + pass + + +def classify_error(error: Exception) -> EnvAgentError: + msg = str(error).lower() + status = getattr(error, "status_code", None) or getattr(error, "status", None) + + if isinstance(error, EnvAgentError): + return error + + if status in (401, 403): + return AuthenticationError(str(error)) + + if status == 429: + retry_after = 5000 + headers = getattr(error, "headers", None) + if headers: + value = headers.get("retry-after") if hasattr(headers, "get") else None + if value: + try: + retry_after = max(1000, int(float(value) * 1000)) + except Exception: + retry_after = 5000 + return RateLimitError(str(error), retry_after_ms=retry_after) + + if status == 529: + return OverloadedError(str(error)) + + if "prompt" in msg and "long" in msg: + return PromptTooLongError(str(error)) + + if any(part in msg for part in ["network", "timed out", "connection", "socket", "fetch"]): + return NetworkError(str(error)) + + return EnvAgentError(str(error)) + + +T = TypeVar("T") + + +async def with_retry( + func: Callable[[int], Awaitable[T]], + max_retries: int = 5, + initial_delay_ms: int = 1000, + max_delay_ms: int = 60_000, +) -> T: + """Run coroutine with retry strategy for retryable errors.""" + + consecutive_overloaded = 0 + last_error: EnvAgentError | None = None + + for attempt in range(max_retries + 1): + try: + return await func(attempt) + except Exception as exc: # pragma: no cover - path exercised in tests + err = classify_error(exc) + last_error = err + + if isinstance(err, (AuthenticationError, PromptTooLongError, AbortError)): + raise err + + if attempt >= max_retries: + raise err + + if isinstance(err, RateLimitError): + delay_ms = err.retry_after_ms + else: + delay_ms = min(max_delay_ms, int(initial_delay_ms * (2**attempt))) + + if isinstance(err, OverloadedError): + consecutive_overloaded += 1 + if consecutive_overloaded >= 3: + raise err + else: + consecutive_overloaded = 0 + + jitter = 0.8 + (random.random() * 0.4) + await asyncio.sleep((delay_ms * jitter) / 1000.0) + + raise last_error or EnvAgentError("retry failed") diff --git a/eagent/core/streaming_executor.py b/eagent/core/streaming_executor.py new file mode 100644 index 0000000..246c127 --- /dev/null +++ b/eagent/core/streaming_executor.py @@ -0,0 +1,286 @@ +"""Streaming tool executor with safe parallel batching.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import Any, Literal + +from eagent.core.types import StreamEvent, Tool, ToolContext, ToolResult, ToolUseBlock + +MAX_CONCURRENCY = 10 + + +@dataclass +class QueuedTool: + block: ToolUseBlock + tool: Tool + is_safe: bool + result: ToolResult | None = None + hook_debug_lines: list[str] = field(default_factory=list) + + +@dataclass +class ToolBatch: + is_concurrency_safe: bool + items: list[QueuedTool] + + +def _partition_tool_calls(items: list[QueuedTool]) -> list[ToolBatch]: + batches: list[ToolBatch] = [] + current_safe: list[QueuedTool] = [] + + for item in items: + if item.is_safe: + current_safe.append(item) + continue + + if current_safe: + batches.append(ToolBatch(is_concurrency_safe=True, items=current_safe)) + current_safe = [] + batches.append(ToolBatch(is_concurrency_safe=False, items=[item])) + + if current_safe: + batches.append(ToolBatch(is_concurrency_safe=True, items=current_safe)) + return batches + + +async def _execute_single(item: QueuedTool, context: ToolContext) -> ToolResult: + def _append_prompt_appends(values: list[str]) -> None: + if not values or context.on_hook_prompt_append is None: + return + for value in values: + context.on_hook_prompt_append(value) + + async def _run_hooks( + event: Literal["pre_tool_use", "post_tool_use", "on_error"], + *, + target: str, + variables: dict[str, Any], + allow_prompt_append: bool = True, + ) -> tuple[bool, str | None]: + runtime = context.hook_runtime + if runtime is None: + return False, None + + outcome = await runtime.run( + event, + target=target, + variables=variables, + cwd=context.cwd, + dev_mode=context.dev_mode, + allow_prompt_append=allow_prompt_append, + ) + item.hook_debug_lines.extend(outcome.debug_lines) + _append_prompt_appends(outcome.prompt_appends) + return outcome.aborted, outcome.abort_reason + + target_name = item.block.name + base_vars: dict[str, Any] = { + "tool_name": item.block.name, + "tool_use_id": item.block.id, + "tool_input": str(item.block.input), + "session_id": context.session_id, + } + + try: + before_aborted, before_reason = await _run_hooks( + "pre_tool_use", + target=target_name, + variables=base_vars, + allow_prompt_append=True, + ) + if before_aborted: + result = ToolResult( + result=before_reason or f"Hook aborted before {item.block.name}.", + is_error=True, + ) + item.result = result + await _run_hooks( + "on_error", + target=target_name, + variables={ + **base_vars, + "source": "pre_tool_use", + "error": result.result, + }, + allow_prompt_append=True, + ) + return result + + result = await item.tool.call(item.block.input, context) + if len(result.result) > item.tool.max_result_size_chars: + result.result = ( + result.result[: item.tool.max_result_size_chars] + + "\n\n[Output truncated: was " + + f"{len(result.result)} chars, " + + f"limit {item.tool.max_result_size_chars}]" + ) + + after_aborted, after_reason = await _run_hooks( + "post_tool_use", + target=target_name, + variables={ + **base_vars, + "tool_result": result.result, + "tool_is_error": result.is_error, + }, + allow_prompt_append=True, + ) + if after_aborted: + result = ToolResult( + result=after_reason or f"Hook aborted after {item.block.name}.", + is_error=True, + ) + + if result.is_error: + await _run_hooks( + "on_error", + target=target_name, + variables={ + **base_vars, + "source": "tool_result", + "error": result.result, + "tool_result": result.result, + }, + allow_prompt_append=True, + ) + + item.result = result + return result + except Exception as exc: + result = ToolResult(result=f"Error executing {item.block.name}: {exc}", is_error=True) + await _run_hooks( + "on_error", + target=target_name, + variables={ + **base_vars, + "source": "tool_exception", + "error": str(exc), + }, + allow_prompt_append=True, + ) + item.result = result + return result + + +async def _execute_parallel_batch( + batch: ToolBatch, context: ToolContext +) -> AsyncGenerator[StreamEvent, None]: + semaphore = asyncio.Semaphore(MAX_CONCURRENCY) + + async def run(item: QueuedTool) -> ToolResult: + async with semaphore: + return await _execute_single(item, context) + + tasks: list[asyncio.Task[ToolResult]] = [] + for item in batch.items: + yield { + "type": "tool_start", + "tool_use_id": item.block.id, + "tool_name": item.block.name, + "input": item.block.input, + } + tasks.append(asyncio.create_task(run(item))) + + # Preserve original order in emitted results. + for item, task in zip(batch.items, tasks, strict=True): + await task + assert item.result is not None + if context.dev_mode and item.hook_debug_lines: + for line in item.hook_debug_lines: + yield {"type": "hook_debug", "text": line} + yield { + "type": "tool_result", + "tool_use_id": item.block.id, + "tool_name": item.block.name, + "result": item.result.result, + "is_error": item.result.is_error, + } + + +async def _execute_serial_batch( + batch: ToolBatch, context: ToolContext +) -> AsyncGenerator[StreamEvent, None]: + item = batch.items[0] + yield { + "type": "tool_start", + "tool_use_id": item.block.id, + "tool_name": item.block.name, + "input": item.block.input, + } + await _execute_single(item, context) + assert item.result is not None + if context.dev_mode and item.hook_debug_lines: + for line in item.hook_debug_lines: + yield {"type": "hook_debug", "text": line} + yield { + "type": "tool_result", + "tool_use_id": item.block.id, + "tool_name": item.block.name, + "result": item.result.result, + "is_error": item.result.is_error, + } + + +def _unknown_tool(name: str) -> Tool: + async def _call(_input, _context): + return ToolResult(result=f"Unknown tool: {name}", is_error=True) + + return Tool( + name=name, + description="", + input_schema={"type": "object"}, + call=_call, + prompt=lambda: "", + is_concurrency_safe=lambda _i: False, + is_read_only=lambda _i: False, + max_result_size_chars=30_000, + user_facing_name=lambda _i: name, + ) + + +async def execute_tools( + tool_use_blocks: list[ToolUseBlock], + tools: list[Tool], + context: ToolContext, +) -> AsyncGenerator[StreamEvent, None]: + tool_map = {tool.name: tool for tool in tools} + + queue: list[QueuedTool] = [] + for block in tool_use_blocks: + tool = tool_map.get(block.name) or _unknown_tool(block.name) + queue.append( + QueuedTool( + block=block, + tool=tool, + is_safe=bool(tool.is_concurrency_safe(block.input)), + ) + ) + + for batch in _partition_tool_calls(queue): + if batch.is_concurrency_safe: + async for event in _execute_parallel_batch(batch, context): + yield event + else: + async for event in _execute_serial_batch(batch, context): + yield event + + +async def execute_tools_collect( + tool_use_blocks: list[ToolUseBlock], + tools: list[Tool], + context: ToolContext, +) -> list[dict[str, object]]: + results: list[dict[str, object]] = [] + async for event in execute_tools(tool_use_blocks, tools, context): + if event["type"] == "tool_result": + results.append( + { + "tool_use_id": event["tool_use_id"], + "tool_name": event["tool_name"], + "result": ToolResult(result=event["result"], is_error=event["is_error"]), + } + ) + return results diff --git a/eagent/core/types.py b/eagent/core/types.py new file mode 100644 index 0000000..6125710 --- /dev/null +++ b/eagent/core/types.py @@ -0,0 +1,325 @@ +"""Shared runtime types for eagent.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass, field +from typing import Any, Literal, Protocol, TypedDict + +MessageRole = Literal["user", "assistant"] +PermissionMode = Literal["default", "plan", "acceptEdits", "bypassPermissions"] +PermissionBehavior = Literal["allow", "deny", "ask"] + + +@dataclass +class TextBlock: + type: Literal["text"] + text: str + + +@dataclass +class ToolUseBlock: + type: Literal["tool_use"] + id: str + name: str + input: dict[str, Any] + + +@dataclass +class ToolResultBlock: + type: Literal["tool_result"] + tool_use_id: str + content: str | list[ContentBlock] + is_error: bool = False + + +@dataclass +class ThinkingBlock: + type: Literal["thinking"] + thinking: str + + +@dataclass +class RedactedThinkingBlock: + type: Literal["redacted_thinking"] + data: str + + +@dataclass +class ImageSource: + type: Literal["base64"] + media_type: str + data: str + + +@dataclass +class ImageBlock: + type: Literal["image"] + source: ImageSource + + +ContentBlock = ( + TextBlock + | ToolUseBlock + | ToolResultBlock + | ThinkingBlock + | RedactedThinkingBlock + | ImageBlock +) + + +@dataclass +class Message: + role: MessageRole + content: list[ContentBlock] + id: str | None = None + + +@dataclass +class TokenUsage: + input_tokens: int = 0 + output_tokens: int = 0 + cache_read_tokens: int = 0 + cache_creation_tokens: int = 0 + + +@dataclass +class ToolResult: + result: str + is_error: bool = False + + +@dataclass +class PermissionDecision: + behavior: PermissionBehavior + message: str | None = None + updated_input: dict[str, Any] | None = None + + +@dataclass +class FileState: + content: str + timestamp: float + offset: int | None = None + limit: int | None = None + is_partial_view: bool = False + + +class FileStateCache(Protocol): + def get(self, path: str) -> FileState | None: ... + + def set(self, path: str, state: FileState) -> None: ... + + def has(self, path: str) -> bool: ... + + def delete(self, path: str) -> None: ... + + def keys(self) -> Iterable[str]: ... + + def clone(self) -> FileStateCache: ... + + def merge(self, other: FileStateCache) -> None: ... + + @property + def size(self) -> int: ... + + +@dataclass +class FileHistoryBackup: + backup_file_name: str | None + version: int + backup_time: float + + +@dataclass +class FileHistorySnapshot: + message_id: str + tracked_file_backups: dict[str, FileHistoryBackup] = field(default_factory=dict) + timestamp: float = 0.0 + + +@dataclass +class FileHistoryState: + snapshots: list[FileHistorySnapshot] = field(default_factory=list) + tracked_files: set[str] = field(default_factory=set) + snapshot_sequence: int = 0 + + +@dataclass +class ToolContext: + cwd: str + read_file_state: FileStateCache + file_history: FileHistoryState + modified_files: set[str] + session_id: str + permission_mode: PermissionMode + on_permission_request: Callable[[str, Any, str], Awaitable[PermissionDecision]] + abort_signal: Any | None = None + hook_runtime: Any | None = None + on_hook_prompt_append: Callable[[str], None] | None = None + dev_mode: bool = False + + +class ToolDef(Protocol): + name: str + description: str | Callable[[dict[str, Any] | None], str] + input_schema: dict[str, Any] + + async def call(self, input_data: dict[str, Any], context: ToolContext) -> ToolResult: ... + + def prompt(self) -> str: ... + + def is_concurrency_safe(self, input_data: dict[str, Any]) -> bool: ... + + def is_read_only(self, input_data: dict[str, Any]) -> bool: ... + + def user_facing_name(self, input_data: dict[str, Any]) -> str: ... + + +@dataclass +class Tool: + name: str + description: str | Callable[[dict[str, Any] | None], str] + input_schema: dict[str, Any] + call: Callable[[dict[str, Any], ToolContext], Awaitable[ToolResult]] + prompt: Callable[[], str] = lambda: "" + is_concurrency_safe: Callable[[dict[str, Any]], bool] = lambda _i: False + is_read_only: Callable[[dict[str, Any]], bool] = lambda _i: False + max_result_size_chars: int = 30_000 + user_facing_name: Callable[[dict[str, Any]], str] = lambda _i: "" + + +@dataclass +class PermissionRule: + tool: str + behavior: Literal["allow", "deny"] + source: Literal["session", "project", "user"] + content: str | None = None + + +@dataclass +class ModelConfig: + model: str + context_window: int + max_output_tokens: int + supports_thinking: bool + supports_caching: bool + price_per_input_token: float + price_per_output_token: float + price_per_cache_read: float + price_per_cache_write: float + + +@dataclass +class SystemPromptBlock: + type: Literal["text"] + text: str + cache_control: dict[str, str] | None = None + + +@dataclass +class QueryParams: + messages: list[Message] + tools: list[Tool] + model_config: ModelConfig + system_prompt_blocks: list[SystemPromptBlock] + permission_mode: PermissionMode + api_key: str + cwd: str + session_id: str + on_permission_request: Callable[[str, Any, str], Awaitable[PermissionDecision]] + read_file_state: FileStateCache + file_history: FileHistoryState + max_turns: int = 200 + abort_signal: Any | None = None + enable_thinking: bool = False + thinking_budget: int | None = None + api_key_override: str | None = None + api_base_url: str | None = None + hook_runtime: Any | None = None + dev_mode: bool = False + + +@dataclass +class CostTracker: + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_cache_read_tokens: int = 0 + total_cache_creation_tokens: int = 0 + turns: int = 0 + + def add(self, usage: TokenUsage) -> None: + self.total_input_tokens += int(usage.input_tokens or 0) + self.total_output_tokens += int(usage.output_tokens or 0) + self.total_cache_read_tokens += int(usage.cache_read_tokens or 0) + self.total_cache_creation_tokens += int(usage.cache_creation_tokens or 0) + self.turns += 1 + + def total_cost_usd(self, config: ModelConfig) -> float: + return ( + self.total_input_tokens * config.price_per_input_token + + self.total_output_tokens * config.price_per_output_token + + self.total_cache_read_tokens * config.price_per_cache_read + + self.total_cache_creation_tokens * config.price_per_cache_write + ) + + +class StreamEvent(TypedDict, total=False): + type: Literal[ + "assistant_text", + "assistant_message", + "tool_use", + "tool_result", + "tool_start", + "turn_complete", + "usage", + "compact", + "error", + "max_turns_reached", + "thinking", + "hook_debug", + ] + text: str + message: Message + tool_use: ToolUseBlock + tool_use_id: str + tool_name: str + result: str + is_error: bool + input: dict[str, Any] + stop_reason: str + usage: TokenUsage + old_tokens: int + new_tokens: int + error: Exception + max_turns: int + + +@dataclass +class CommandContext: + messages: list[Message] + tools: list[Tool] + model_config: ModelConfig + cwd: str + session_id: str + cost_tracker: CostTracker + file_history: FileHistoryState + read_file_state: FileStateCache + permission_mode: PermissionMode + set_permission_mode: Callable[[PermissionMode], None] + set_model: Callable[[str], str] + clear_messages: Callable[[], None] + compact: Callable[[], Awaitable[None]] + resume_session: Callable[[str], Awaitable[str | None]] + send_prompt: Callable[[str], None] + set_input_draft: Callable[[str], None] | None = None + interactive: bool = False + new_session: Callable[[], Awaitable[str]] | None = None + dev_mode: bool = False + + +class SlashCommand(Protocol): + name: str + description: str + + async def execute(self, args: str, context: CommandContext) -> str | None: ... diff --git a/eagent/files/__init__.py b/eagent/files/__init__.py new file mode 100644 index 0000000..0bc32f9 --- /dev/null +++ b/eagent/files/__init__.py @@ -0,0 +1,22 @@ +"""File state modules.""" + +from eagent.files.atomic_write import atomic_write +from eagent.files.cache import LruFileStateCache, create_file_state_cache +from eagent.files.history import ( + create_file_history_state, + get_diff_stats, + make_snapshot, + rewind, + track_edit, +) + +__all__ = [ + "atomic_write", + "LruFileStateCache", + "create_file_state_cache", + "create_file_history_state", + "track_edit", + "make_snapshot", + "rewind", + "get_diff_stats", +] diff --git a/eagent/files/atomic_write.py b/eagent/files/atomic_write.py new file mode 100644 index 0000000..5f6b059 --- /dev/null +++ b/eagent/files/atomic_write.py @@ -0,0 +1,48 @@ +"""Atomic file write helper.""" + +from __future__ import annotations + +import os +import tempfile +from pathlib import Path + + +def _file_mode(path: Path) -> int | None: + try: + return path.stat().st_mode + except OSError: + return None + + +def atomic_write(file_path: str, content: str) -> None: + target = Path(file_path).expanduser().resolve() + target.parent.mkdir(parents=True, exist_ok=True) + + original_mode = _file_mode(target) + + tmp_fd, tmp_name = tempfile.mkstemp(prefix=f"{target.name}.tmp.", dir=str(target.parent)) + tmp = Path(tmp_name) + try: + with os.fdopen(tmp_fd, "w", encoding="utf-8") as f: + f.write(content) + if original_mode is not None: + try: + os.chmod(tmp, original_mode) + except OSError: + pass + + try: + os.replace(tmp, target) + except OSError: + target.write_text(content, encoding="utf-8") + if original_mode is not None: + try: + os.chmod(target, original_mode) + except OSError: + pass + if tmp.exists(): + tmp.unlink() + except Exception: + if tmp.exists(): + tmp.unlink() + raise diff --git a/eagent/files/cache.py b/eagent/files/cache.py new file mode 100644 index 0000000..6496b65 --- /dev/null +++ b/eagent/files/cache.py @@ -0,0 +1,96 @@ +"""LRU file state cache.""" + +from __future__ import annotations + +from collections import OrderedDict +from pathlib import Path + +from eagent.core.types import FileState + +MAX_ENTRIES = 100 +MAX_TOTAL_BYTES = 25 * 1024 * 1024 + + +class LruFileStateCache: + def __init__( + self, max_entries: int = MAX_ENTRIES, max_total_bytes: int = MAX_TOTAL_BYTES + ) -> None: + self._max_entries = max_entries + self._max_total_bytes = max_total_bytes + self._store: OrderedDict[str, tuple[FileState, int]] = OrderedDict() + self._bytes = 0 + + def _key(self, path: str) -> str: + return str(Path(path).expanduser().resolve()) + + def _size(self, state: FileState) -> int: + return len(state.content.encode("utf-8")) + + def _evict(self) -> None: + while self._store and ( + len(self._store) > self._max_entries or self._bytes > self._max_total_bytes + ): + _k, (_s, b) = self._store.popitem(last=False) + self._bytes -= b + + def get(self, path: str) -> FileState | None: + key = self._key(path) + item = self._store.get(key) + if item is None: + return None + self._store.move_to_end(key) + return item[0] + + def set(self, path: str, state: FileState) -> None: + key = self._key(path) + old = self._store.pop(key, None) + if old: + self._bytes -= old[1] + size = self._size(state) + self._store[key] = (state, size) + self._bytes += size + self._evict() + + def has(self, path: str) -> bool: + return self._key(path) in self._store + + def delete(self, path: str) -> None: + key = self._key(path) + old = self._store.pop(key, None) + if old: + self._bytes -= old[1] + + def keys(self): + return list(self._store.keys()) + + def clone(self) -> LruFileStateCache: + cloned = LruFileStateCache(self._max_entries, self._max_total_bytes) + for key, (state, _bytes) in self._store.items(): + cloned.set( + key, + FileState( + content=state.content, + timestamp=state.timestamp, + offset=state.offset, + limit=state.limit, + is_partial_view=state.is_partial_view, + ), + ) + return cloned + + def merge(self, other) -> None: + for key in other.keys(): + other_state = other.get(key) + if other_state is None: + continue + current = self.get(key) + if current is None or other_state.timestamp > current.timestamp: + self.set(key, other_state) + + @property + def size(self) -> int: + return len(self._store) + + +def create_file_state_cache() -> LruFileStateCache: + return LruFileStateCache() diff --git a/eagent/files/history.py b/eagent/files/history.py new file mode 100644 index 0000000..6207ef6 --- /dev/null +++ b/eagent/files/history.py @@ -0,0 +1,138 @@ +"""File history snapshot helpers.""" + +from __future__ import annotations + +import hashlib +import time +from pathlib import Path + +from eagent.core.types import FileHistoryBackup, FileHistorySnapshot, FileHistoryState +from eagent.paths import env_root + +MAX_SNAPSHOTS = 100 + + +def _history_dir(session_id: str) -> Path: + return env_root() / "file-history" / session_id + + +def _path_hash(file_path: str) -> str: + return hashlib.sha256(file_path.encode("utf-8")).hexdigest()[:16] + + +def create_file_history_state() -> FileHistoryState: + return FileHistoryState() + + +async def track_edit(state: FileHistoryState, file_path: str, session_id: str) -> None: + _ = session_id + state.tracked_files.add(str(Path(file_path).expanduser().resolve())) + + +async def make_snapshot( + state: FileHistoryState, message_id: str, session_id: str +) -> FileHistorySnapshot: + now = time.time() + root = _history_dir(session_id) + root.mkdir(parents=True, exist_ok=True) + + backups: dict[str, FileHistoryBackup] = {} + for file_path in sorted(state.tracked_files): + p = Path(file_path) + + version = 1 + for snap in reversed(state.snapshots): + prev = snap.tracked_file_backups.get(file_path) + if prev is not None: + version = prev.version + 1 + break + + if p.exists() and p.is_file(): + backup_name = f"{_path_hash(file_path)}@v{version}" + (root / backup_name).write_bytes(p.read_bytes()) + backups[file_path] = FileHistoryBackup( + backup_file_name=backup_name, + version=version, + backup_time=now, + ) + else: + backups[file_path] = FileHistoryBackup( + backup_file_name=None, + version=version, + backup_time=now, + ) + + snapshot = FileHistorySnapshot( + message_id=message_id, tracked_file_backups=backups, timestamp=now + ) + state.snapshots.append(snapshot) + state.snapshot_sequence += 1 + if len(state.snapshots) > MAX_SNAPSHOTS: + state.snapshots = state.snapshots[-MAX_SNAPSHOTS:] + return snapshot + + +async def rewind(state: FileHistoryState, snapshot_index: int, session_id: str) -> None: + if snapshot_index < 0 or snapshot_index >= len(state.snapshots): + raise ValueError("Invalid snapshot index") + + target = state.snapshots[snapshot_index] + root = _history_dir(session_id) + for file_path, backup in target.tracked_file_backups.items(): + p = Path(file_path) + if backup.backup_file_name is None: + if p.exists(): + p.unlink() + continue + + src = root / backup.backup_file_name + if src.exists(): + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(src.read_bytes()) + + state.snapshots = state.snapshots[: snapshot_index + 1] + + +async def get_diff_stats( + state: FileHistoryState, + snapshot_a: int, + snapshot_b: int, + session_id: str, +) -> list[dict[str, int | str]]: + if ( + snapshot_a < 0 + or snapshot_b < 0 + or snapshot_a >= len(state.snapshots) + or snapshot_b >= len(state.snapshots) + ): + raise ValueError("Invalid snapshot index") + + root = _history_dir(session_id) + a = state.snapshots[snapshot_a] + b = state.snapshots[snapshot_b] + files = set(a.tracked_file_backups) | set(b.tracked_file_backups) + + def _read(backup: FileHistoryBackup | None) -> str: + if backup is None or backup.backup_file_name is None: + return "" + path = root / backup.backup_file_name + if not path.exists(): + return "" + return path.read_text(encoding="utf-8", errors="replace") + + out: list[dict[str, int | str]] = [] + for file_path in sorted(files): + left = _read(a.tracked_file_backups.get(file_path)) + right = _read(b.tracked_file_backups.get(file_path)) + if left == right: + continue + left_lines = set(left.splitlines()) + right_lines = set(right.splitlines()) + out.append( + { + "file": file_path, + "insertions": len([x for x in right_lines if x not in left_lines]), + "deletions": len([x for x in left_lines if x not in right_lines]), + } + ) + return out diff --git a/eagent/files/utils.py b/eagent/files/utils.py new file mode 100644 index 0000000..f46743b --- /dev/null +++ b/eagent/files/utils.py @@ -0,0 +1,55 @@ +"""File utility helpers used by tools and permissions.""" + +from __future__ import annotations + +import os +from pathlib import Path + + +def normalize_path(path: str) -> str: + return str(Path(path).expanduser().resolve()) + + +def normalize_line_endings(text: str) -> str: + return text.replace("\r\n", "\n") + + +def is_within_project(file_path: str, project_root: str) -> bool: + file_abs = Path(file_path).expanduser().resolve() + root_abs = Path(project_root).expanduser().resolve() + return file_abs == root_abs or root_abs in file_abs.parents + + +def detect_encoding(data: bytes) -> str: + if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xFE: + return "utf-16le" + if len(data) >= 3 and data[0] == 0xEF and data[1] == 0xBB and data[2] == 0xBF: + return "utf-8-sig" + return "utf-8" + + +def is_binary_data(data: bytes) -> bool: + sample = data[:8192] + return b"\x00" in sample + + +def is_binary_file(file_path: str) -> bool: + path = Path(file_path) + try: + data = path.read_bytes()[:8192] + return is_binary_data(data) + except Exception: + return False + + +def format_with_line_numbers(text: str, start_line: int = 1) -> str: + lines = text.split("\n") + width = len(str(start_line + len(lines) - 1)) if lines else 1 + return "\n".join(f"{str(i).rjust(width)}\t{line}" for i, line in enumerate(lines, start_line)) + + +def get_file_size(file_path: str) -> int: + try: + return os.path.getsize(file_path) + except OSError: + return -1 diff --git a/eagent/headless.py b/eagent/headless.py new file mode 100644 index 0000000..05172c7 --- /dev/null +++ b/eagent/headless.py @@ -0,0 +1,249 @@ +"""Headless SDK interface for eagent.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncGenerator, Callable +from dataclasses import dataclass +from typing import Any + +from eagent.context.git_context import get_git_context +from eagent.context.memory import load_agent_memory +from eagent.core.agent_loop import agent_loop +from eagent.core.api_client import get_model_config +from eagent.core.types import ( + Message, + PermissionDecision, + PermissionMode, + QueryParams, + StreamEvent, + TextBlock, + Tool, +) +from eagent.files.cache import create_file_state_cache +from eagent.files.history import create_file_history_state +from eagent.hooks import HookRuntime +from eagent.mcp.manager import initialize_mcp_servers +from eagent.prompt.system_prompt import build_system_prompt_blocks +from eagent.skills.skill_tool import set_skill_query_params +from eagent.tools.agent_tool import set_agent_query_params +from eagent.tools.registry import initialize_tools, register_dynamic_tools +from eagent.utils.cost import create_cost_tracker, summarize_cost + + +@dataclass +class AgentOptions: + api_key: str + model: str = "sonnet" + cwd: str = "." + max_turns: int = 200 + permission_mode: PermissionMode = "bypassPermissions" + enable_thinking: bool = False + thinking_budget: int | None = None + tools: list[Tool] | None = None + on_permission_request: Callable[[str, Any, str], Any] | None = None + enable_mcp: bool = False + + +class AgentInstance: + def __init__(self, options: AgentOptions) -> None: + self.options = options + self.cwd = options.cwd + self.model_config = get_model_config(options.model) + self.session_id = str(uuid.uuid4()) + self.cost_tracker = create_cost_tracker() + self.read_file_state = create_file_state_cache() + self.file_history = create_file_history_state() + self.messages: list[Message] = [] + self.system_prompt_blocks = [] + self.tools: list[Tool] = [] + self.hook_runtime = HookRuntime(self.cwd) + self._session_hooks_ran = False + self._session_end_hooks_ran = False + + def _append_user_message(self, text: str) -> None: + self.messages.append( + Message(role="user", content=[TextBlock(type="text", text=text)], id=str(uuid.uuid4())) + ) + + def _last_assistant_message(self) -> str: + for message in reversed(self.messages): + if message.role != "assistant": + continue + chunks: list[str] = [] + for block in message.content: + if isinstance(block, TextBlock): + text = block.text.strip() + if text: + chunks.append(text) + if chunks: + return "\n".join(chunks).strip() + return "" + + async def initialize(self) -> None: + agent_memory = await load_agent_memory(self.cwd) + git_context = await get_git_context(self.cwd) + self.system_prompt_blocks = build_system_prompt_blocks( + agent_memory, git_context, self.cwd, self.model_config.model + ) + + self.tools = self.options.tools or await initialize_tools(self.cwd) + if self.options.enable_mcp: + mcp_tools = await initialize_mcp_servers(self.cwd) + if mcp_tools: + register_dynamic_tools(mcp_tools) + self.tools.extend(mcp_tools) + + async def query(self, prompt: str) -> AsyncGenerator[StreamEvent, None]: + if not self._session_hooks_ran: + self._session_hooks_ran = True + session_outcome = await self.hook_runtime.run( + "session_start", + target=self.session_id, + variables={"session_id": self.session_id, "cwd": self.cwd}, + cwd=self.cwd, + dev_mode=False, + ) + if session_outcome.aborted: + yield { + "type": "error", + "error": Exception( + session_outcome.abort_reason or "Session start hook aborted." + ), + } + return + for extra_prompt in session_outcome.prompt_appends: + if extra_prompt.strip(): + self._append_user_message(extra_prompt) + + prompt_text = prompt.strip() + user_outcome = await self.hook_runtime.run( + "user_prompt_submit", + target=prompt_text if prompt_text else "(empty)", + variables={"prompt": prompt, "prompt_text": prompt_text, "cwd": self.cwd}, + cwd=self.cwd, + dev_mode=False, + ) + if user_outcome.aborted: + yield { + "type": "error", + "error": Exception(user_outcome.abort_reason or "User prompt hook aborted."), + } + return + for extra_prompt in user_outcome.prompt_appends: + if extra_prompt.strip(): + self._append_user_message(extra_prompt) + + self._append_user_message(prompt) + + async def _default_permission_handler( + _tool: str, _input: Any, _message: str + ) -> PermissionDecision: + return PermissionDecision(behavior="allow") + + on_permission_request = self.options.on_permission_request or _default_permission_handler + + params = QueryParams( + messages=self.messages, + tools=self.tools, + model_config=self.model_config, + system_prompt_blocks=self.system_prompt_blocks, + max_turns=self.options.max_turns, + permission_mode=self.options.permission_mode, + api_key=self.options.api_key, + cwd=self.cwd, + session_id=self.session_id, + on_permission_request=on_permission_request, + enable_thinking=self.options.enable_thinking, + thinking_budget=self.options.thinking_budget, + read_file_state=self.read_file_state, + file_history=self.file_history, + hook_runtime=self.hook_runtime, + dev_mode=False, + ) + + set_agent_query_params(params) + set_skill_query_params(params) + + stop_target = "turn_complete" + stop_error = "" + async for event in agent_loop(params): + if event["type"] == "usage": + self.cost_tracker.add(event["usage"]) + if event["type"] == "turn_complete": + stop_target = str(event.get("stop_reason") or "turn_complete") + elif event["type"] == "max_turns_reached": + stop_target = f"max_turns:{event.get('max_turns')}" + elif event["type"] == "error": + stop_target = "error" + stop_error = str(event.get("error") or "") + yield event + + stop_outcome = await self.hook_runtime.run( + "stop", + target=stop_target, + variables={ + "stop_reason": stop_target, + "error": stop_error, + "cwd": self.cwd, + "last_assistant_message": self._last_assistant_message(), + }, + cwd=self.cwd, + dev_mode=False, + ) + if stop_outcome.aborted: + yield { + "type": "error", + "error": Exception(stop_outcome.abort_reason or "Stop hook aborted."), + } + return + for extra_prompt in stop_outcome.prompt_appends: + if extra_prompt.strip(): + self._append_user_message(extra_prompt) + + return + + def get_messages(self) -> list[Message]: + return list(self.messages) + + def get_cost_summary(self) -> str: + return summarize_cost(self.cost_tracker, self.model_config) + + def get_session_id(self) -> str: + return self.session_id + + def reset(self) -> None: + self.messages = [] + self._session_hooks_ran = False + self._session_end_hooks_ran = False + + async def close(self) -> None: + if self._session_end_hooks_ran: + return + self._session_end_hooks_ran = True + await self.hook_runtime.run( + "session_end", + target=self.session_id, + variables={"session_id": self.session_id, "cwd": self.cwd}, + cwd=self.cwd, + dev_mode=False, + allow_prompt_append=False, + ) + + +async def create_agent(options: AgentOptions) -> AgentInstance: + agent = AgentInstance(options) + await agent.initialize() + return agent + + +async def one_shot(prompt: str, options: AgentOptions) -> dict[str, Any]: + agent = await create_agent(options) + text_parts: list[str] = [] + try: + async for event in agent.query(prompt): + if event["type"] == "assistant_text": + text_parts.append(event["text"]) + return {"text": "".join(text_parts), "messages": agent.get_messages()} + finally: + await agent.close() diff --git a/eagent/hooks/__init__.py b/eagent/hooks/__init__.py new file mode 100644 index 0000000..8aae2a5 --- /dev/null +++ b/eagent/hooks/__init__.py @@ -0,0 +1,5 @@ +"""Hook framework for lifecycle events.""" + +from eagent.hooks.runtime import HookRuntime + +__all__ = ["HookRuntime"] diff --git a/eagent/hooks/runtime.py b/eagent/hooks/runtime.py new file mode 100644 index 0000000..652d3e0 --- /dev/null +++ b/eagent/hooks/runtime.py @@ -0,0 +1,819 @@ +"""Hook loading and execution runtime.""" + +from __future__ import annotations + +import asyncio +import fnmatch +import json +import os +import re +from collections.abc import Mapping +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +from eagent.paths import env_root +from eagent.skills.loader import parse_frontmatter + +HookEventName = Literal[ + "session_start", + "session_end", + "user_prompt_submit", + "pre_tool_use", + "post_tool_use", + "stop", + "before_command", + "after_command", + "on_error", +] +HookAction = Literal["bash", "prompt_append"] +HookFailureMode = Literal["continue", "abort"] + +HOOK_EVENTS: tuple[HookEventName, ...] = ( + "session_start", + "session_end", + "user_prompt_submit", + "pre_tool_use", + "post_tool_use", + "stop", + "before_command", + "after_command", + "on_error", +) +HOOK_EVENT_DIR_ALIASES: dict[HookEventName, tuple[str, ...]] = { + "session_start": ("session_start",), + "session_end": ("session_end",), + "user_prompt_submit": ("user_prompt_submit",), + "pre_tool_use": ("pre_tool_use", "before_tool"), + "post_tool_use": ("post_tool_use", "after_tool"), + "stop": ("stop",), + "before_command": ("before_command",), + "after_command": ("after_command",), + "on_error": ("on_error",), +} +PROJECT_HOOKS_DIR = Path(".agents") / "hooks" +USER_HOOKS_DIR = Path("hooks") +HOOKS_JSON_FILE_NAME = "hooks.json" +HOOKS_RUNTIME_DIR = Path(".agents") / "runtime" +DEFAULT_TIMEOUT_SECONDS = 120 +_VALID_ACTIONS = {"bash", "prompt_append"} +_VALID_FAILURE_MODES = {"continue", "abort"} +_BRACE_PATTERN = re.compile(r"{{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*}}") +_DOLLAR_PATTERN = re.compile(r"\$([a-zA-Z_][a-zA-Z0-9_]*)") +_DOLLAR_BRACE_PATTERN = re.compile(r"\$\{([a-zA-Z_][a-zA-Z0-9_]*)\}") +_EXPORT_PATTERN = re.compile(r"^\s*export\s+([A-Za-z_][A-Za-z0-9_]*)=(.*)$") + +_CLAUDE_EVENT_BY_INTERNAL: dict[HookEventName, str] = { + "session_start": "SessionStart", + "session_end": "SessionEnd", + "user_prompt_submit": "UserPromptSubmit", + "pre_tool_use": "PreToolUse", + "post_tool_use": "PostToolUse", + "stop": "Stop", + "before_command": "BeforeCommand", + "after_command": "AfterCommand", + "on_error": "OnError", +} + +_EVENT_TOKEN_TO_INTERNAL: dict[str, HookEventName] = { + "sessionstart": "session_start", + "sessionend": "session_end", + "userpromptsubmit": "user_prompt_submit", + "pretooluse": "pre_tool_use", + "posttooluse": "post_tool_use", + "stop": "stop", + "beforecommand": "before_command", + "aftercommand": "after_command", + "onerror": "on_error", +} + + +@dataclass(frozen=True) +class HookDefinition: + """Single hook entry loaded from markdown or hooks.json.""" + + event: HookEventName + name: str + description: str + match: str + action: HookAction + on_failure: HookFailureMode + timeout_seconds: int + command: str | None + template: str + source_path: Path + source: Literal["markdown", "json"] = "markdown" + pass_stdin_json: bool = False + parse_stdout_decision: bool = False + + +@dataclass(frozen=True) +class HookCommandResult: + """Result of bash hook execution.""" + + return_code: int + stdout: str + stderr: str + timed_out: bool = False + + +@dataclass +class HookExecutionOutcome: + """Aggregated hook execution output.""" + + aborted: bool = False + abort_reason: str | None = None + prompt_appends: list[str] = field(default_factory=list) + debug_lines: list[str] = field(default_factory=list) + + +def _to_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return str(value) + return str(value) + + +def _read_file(path: Path) -> str | None: + try: + text = path.read_text(encoding="utf-8") + except Exception: + return None + stripped = text.strip() + return stripped if stripped else None + + +def _extract_title(text: str) -> str: + for raw in text.splitlines(): + line = raw.strip() + if not line: + continue + if line.startswith("#"): + return line.lstrip("#").strip()[:100] + return line[:100] + return "Hook" + + +def _normalize_action(value: Any) -> HookAction | None: + text = str(value or "").strip().lower() + if text in _VALID_ACTIONS: + return text # type: ignore[return-value] + return None + + +def _normalize_failure_mode(value: Any) -> HookFailureMode | None: + text = str(value or "").strip().lower() + if text in _VALID_FAILURE_MODES: + return text # type: ignore[return-value] + return None + + +def _normalize_timeout(value: Any) -> int: + if isinstance(value, (int, float)): + timeout = int(value) + return timeout if timeout > 0 else DEFAULT_TIMEOUT_SECONDS + if isinstance(value, str): + stripped = value.strip() + if stripped.isdigit(): + timeout = int(stripped) + return timeout if timeout > 0 else DEFAULT_TIMEOUT_SECONDS + return DEFAULT_TIMEOUT_SECONDS + + +def _normalize_event_token(value: str) -> str: + return re.sub(r"[^a-z0-9]+", "", value.strip().lower()) + + +def _canonical_event_name(value: str) -> HookEventName | None: + token = _normalize_event_token(value) + if not token: + return None + return _EVENT_TOKEN_TO_INTERNAL.get(token) + + +def _claude_event_name(event: HookEventName) -> str: + return _CLAUDE_EVENT_BY_INTERNAL.get(event, event) + + +def _decode_export_value(raw_value: str) -> str: + value = raw_value.strip() + if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}: + inner = value[1:-1] + if value[0] == "'": + # Support shell-escaped single quotes from hooks that append to CLAUDE_ENV_FILE. + return inner.replace("'\"'\"'", "'") + return inner.replace('\\"', '"').replace("\\$", "$") + return value + + +def _load_exported_env(env_file: Path) -> dict[str, str]: + if not env_file.exists() or not env_file.is_file(): + return {} + loaded: dict[str, str] = {} + try: + content = env_file.read_text(encoding="utf-8") + except Exception: + return loaded + for raw_line in content.splitlines(): + match = _EXPORT_PATTERN.match(raw_line) + if not match: + continue + name = match.group(1) + loaded[name] = _decode_export_value(match.group(2)) + return loaded + + +def _ensure_runtime_paths(cwd: str) -> tuple[Path, Path]: + runtime_dir = (Path(cwd).resolve() / HOOKS_RUNTIME_DIR).resolve() + runtime_dir.mkdir(parents=True, exist_ok=True) + plugin_data_dir = runtime_dir / "plugin_data" + plugin_data_dir.mkdir(parents=True, exist_ok=True) + env_file = runtime_dir / "claude_env.sh" + env_file.parent.mkdir(parents=True, exist_ok=True) + if not env_file.exists(): + env_file.touch() + return env_file, plugin_data_dir + + +def _build_hook_env(cwd: str, context: Mapping[str, Any]) -> dict[str, str]: + env = os.environ.copy() + env_file, plugin_data_dir = _ensure_runtime_paths(cwd) + project_dir = str(Path(cwd).resolve()) + env.update( + { + "CLAUDE_PROJECT_DIR": project_dir, + "CLAUDE_PLUGIN_ROOT": project_dir, + "CLAUDE_PLUGIN_DATA": str(plugin_data_dir), + "CLAUDE_ENV_FILE": str(env_file), + } + ) + + explicit_map = { + "claude_project_dir": "CLAUDE_PROJECT_DIR", + "claude_plugin_root": "CLAUDE_PLUGIN_ROOT", + "claude_plugin_data": "CLAUDE_PLUGIN_DATA", + "claude_env_file": "CLAUDE_ENV_FILE", + } + for key, env_name in explicit_map.items(): + value = context.get(key) + if value is None: + continue + rendered = str(value).strip() + if rendered: + env[env_name] = rendered + + resolved_env_file = Path(env["CLAUDE_ENV_FILE"]).expanduser() + resolved_env_file.parent.mkdir(parents=True, exist_ok=True) + if not resolved_env_file.exists(): + resolved_env_file.touch() + env.update(_load_exported_env(resolved_env_file)) + + env["RTE_PROJECT_DIR"] = env["CLAUDE_PROJECT_DIR"] + env["RTE_PLUGIN_ROOT"] = env["CLAUDE_PLUGIN_ROOT"] + env["RTE_PLUGIN_DATA"] = env["CLAUDE_PLUGIN_DATA"] + return env + + +def _json_safe(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, Mapping): + return {str(key): _json_safe(val) for key, val in value.items()} + if isinstance(value, (list, tuple, set)): + return [_json_safe(item) for item in value] + return str(value) + + +def _build_hook_stdin_payload( + event: HookEventName, target: str, context: Mapping[str, Any] +) -> str: + payload = {str(key): _json_safe(value) for key, value in context.items()} + payload.setdefault("hook_event_name", _claude_event_name(event)) + payload.setdefault("rte_hook_event_name", event) + payload.setdefault("target", target) + return json.dumps(payload, ensure_ascii=False) + + +def _parse_hook_decision(stdout_text: str) -> tuple[bool, str | None, str | None]: + stripped = stdout_text.strip() + if not stripped: + return False, None, None + + candidates: list[str] = [stripped] + lines = [line.strip() for line in stripped.splitlines() if line.strip()] + if len(lines) > 1: + candidates.extend(reversed(lines)) + + last_error: str | None = None + for candidate in candidates: + try: + payload = json.loads(candidate) + except json.JSONDecodeError as exc: + last_error = str(exc) + continue + + if not isinstance(payload, Mapping): + return False, None, None + decision = str(payload.get("decision") or "").strip().lower() + if decision != "block": + return False, None, None + reason = str(payload.get("reason") or "").strip() + return True, reason or "Hook requested block.", None + + return False, None, last_error + + +def _render_template(template: str, variables: Mapping[str, Any]) -> str: + rendered = template + lookup = {key: _to_string(value) for key, value in variables.items()} + + def replace_brace(match: re.Match[str]) -> str: + key = match.group(1) + if key not in lookup: + return match.group(0) + return lookup[key] + + def replace_dollar(match: re.Match[str]) -> str: + key = match.group(1) + if key not in lookup: + return match.group(0) + return lookup[key] + + rendered = _BRACE_PATTERN.sub(replace_brace, rendered) + rendered = _DOLLAR_BRACE_PATTERN.sub(replace_dollar, rendered) + rendered = _DOLLAR_PATTERN.sub(replace_dollar, rendered) + return rendered + + +def _hook_name_from_path(path: Path, event_dir: Path) -> str: + relative = path.relative_to(event_dir).with_suffix("") + return ":".join(relative.parts) + + +def _parse_hook_file( + path: Path, event: HookEventName, event_dir: Path +) -> tuple[HookDefinition | None, str | None]: + content = _read_file(path) + if not content: + return None, "hook file is empty" + + frontmatter, body = parse_frontmatter(content) + action = _normalize_action(frontmatter.get("action")) + if action is None: + return None, "frontmatter `action` must be one of: bash, prompt_append" + + on_failure = _normalize_failure_mode( + frontmatter.get("on_failure") or frontmatter.get("on-failure") + ) + if on_failure is None: + return None, "frontmatter `on_failure` must be one of: continue, abort" + + command_raw = frontmatter.get("command") + command = str(command_raw).strip() if isinstance(command_raw, str) else None + template = body.strip() + if action == "bash" and not command: + return None, "frontmatter `command` is required for bash action" + if action == "prompt_append" and not template: + return None, "prompt_append hook requires non-empty markdown body" + + description_raw = frontmatter.get("description") + description = ( + str(description_raw).strip() + if isinstance(description_raw, str) and description_raw.strip() + else _extract_title(template if action == "prompt_append" else command or "") + ) + match_raw = frontmatter.get("match") + match_pattern = ( + str(match_raw).strip() if isinstance(match_raw, str) and match_raw.strip() else "*" + ) + timeout_seconds = _normalize_timeout(frontmatter.get("timeout")) + + definition = HookDefinition( + event=event, + name=_hook_name_from_path(path, event_dir), + description=description, + match=match_pattern, + action=action, + on_failure=on_failure, + timeout_seconds=timeout_seconds, + command=command, + template=template, + source_path=path, + ) + return definition, None + + +def _load_hooks_json( + hooks_root: Path, + source_kind: str, + hooks_by_event: dict[HookEventName, list[HookDefinition]], + seen_keys: set[str], + load_errors: list[str], +) -> None: + hooks_json_path = hooks_root / HOOKS_JSON_FILE_NAME + if not hooks_json_path.exists() or not hooks_json_path.is_file(): + return + + try: + raw_payload = json.loads(hooks_json_path.read_text(encoding="utf-8")) + except Exception as exc: + load_errors.append( + f"[hooks] ignored invalid {source_kind} hooks file {hooks_json_path}: {exc}" + ) + return + + hooks_payload = raw_payload.get("hooks") if isinstance(raw_payload, Mapping) else None + if not isinstance(hooks_payload, Mapping): + load_errors.append( + f"[hooks] ignored invalid {source_kind} hooks file {hooks_json_path}: " + "`hooks` object is required" + ) + return + + for raw_event_name, groups in hooks_payload.items(): + if not isinstance(raw_event_name, str): + load_errors.append( + f"[hooks] ignored invalid {source_kind} hooks entry in {hooks_json_path}: " + "event name must be string" + ) + continue + + event = _canonical_event_name(raw_event_name) + if event is None: + load_errors.append( + f"[hooks] ignored unsupported {source_kind} hooks event " + f"`{raw_event_name}` in {hooks_json_path}" + ) + continue + + if isinstance(groups, list): + group_entries = groups + elif isinstance(groups, Mapping): + group_entries = [groups] + else: + load_errors.append( + f"[hooks] ignored invalid {source_kind} hooks event " + f"`{raw_event_name}` in {hooks_json_path}: expected array/object" + ) + continue + + for group_index, group in enumerate(group_entries): + if isinstance(group, Mapping): + hooks_list = group.get("hooks") + if isinstance(hooks_list, list): + command_entries = hooks_list + elif isinstance(hooks_list, Mapping): + command_entries = [hooks_list] + else: + load_errors.append( + f"[hooks] ignored invalid {source_kind} hooks group " + f"`{raw_event_name}[{group_index}]` in {hooks_json_path}: " + "`hooks` list is required" + ) + continue + elif isinstance(group, list): + command_entries = group + else: + load_errors.append( + f"[hooks] ignored invalid {source_kind} hooks group " + f"`{raw_event_name}[{group_index}]` in {hooks_json_path}: " + "expected object/array" + ) + continue + + for command_index, command_entry in enumerate(command_entries): + if not isinstance(command_entry, Mapping): + load_errors.append( + f"[hooks] ignored invalid {source_kind} command hook " + f"`{raw_event_name}[{group_index}][{command_index}]` in " + f"{hooks_json_path}: expected object" + ) + continue + + hook_type = str(command_entry.get("type") or "").strip().lower() + if hook_type != "command": + load_errors.append( + f"[hooks] ignored unsupported {source_kind} hook type " + f"`{hook_type or ''}` in {hooks_json_path}" + ) + continue + + command_raw = command_entry.get("command") + command = str(command_raw).strip() if isinstance(command_raw, str) else "" + if not command: + load_errors.append( + f"[hooks] ignored invalid {source_kind} command hook " + f"`{raw_event_name}[{group_index}][{command_index}]` in " + f"{hooks_json_path}: `command` is required" + ) + continue + + timeout_seconds = _normalize_timeout(command_entry.get("timeout")) + name_raw = command_entry.get("name") + name = ( + str(name_raw).strip() + if isinstance(name_raw, str) and str(name_raw).strip() + else f"json:{raw_event_name}:{group_index}:{command_index}" + ) + relative_key = ( + f"json:{_normalize_event_token(raw_event_name)}:{group_index}:{command_index}" + ) + key = f"{event}:{relative_key}" + if key in seen_keys: + continue + + hooks_by_event[event].append( + HookDefinition( + event=event, + name=name, + description=f"{raw_event_name} command hook", + match="*", + action="bash", + on_failure="continue", + timeout_seconds=timeout_seconds, + command=command, + template="", + source_path=hooks_json_path, + source="json", + pass_stdin_json=True, + parse_stdout_decision=True, + ) + ) + seen_keys.add(key) + + +def _load_hooks(cwd: str) -> tuple[dict[HookEventName, list[HookDefinition]], list[str]]: + project_root = Path(cwd).resolve() + roots = [ + ("project", project_root / PROJECT_HOOKS_DIR), + ("user", env_root() / USER_HOOKS_DIR), + ] + hooks_by_event: dict[HookEventName, list[HookDefinition]] = {event: [] for event in HOOK_EVENTS} + load_errors: list[str] = [] + seen_keys: set[str] = set() + + for source_kind, hooks_root in roots: + for event in HOOK_EVENTS: + aliases = HOOK_EVENT_DIR_ALIASES.get(event, (event,)) + for alias in aliases: + event_dir = hooks_root / alias + if not event_dir.exists() or not event_dir.is_dir(): + continue + + for hook_file in sorted(event_dir.rglob("*.md")): + relative = hook_file.relative_to(event_dir).with_suffix("").as_posix().lower() + key = f"{event}:{relative}" + if key in seen_keys: + continue + + definition, error = _parse_hook_file(hook_file, event, event_dir) + if definition is None: + load_errors.append( + f"[hooks] ignored invalid {source_kind} hook {hook_file}: {error}" + ) + continue + + hooks_by_event[event].append(definition) + seen_keys.add(key) + + _load_hooks_json( + hooks_root=hooks_root, + source_kind=source_kind, + hooks_by_event=hooks_by_event, + seen_keys=seen_keys, + load_errors=load_errors, + ) + + return hooks_by_event, load_errors + + +def _match_target(pattern: str, target: str) -> bool: + if not pattern: + return True + return fnmatch.fnmatch(target.lower(), pattern.lower()) + + +def _truncate_debug(text: str, limit: int = 180) -> str: + stripped = text.strip() + if len(stripped) <= limit: + return stripped + return stripped[:limit] + " ..." + + +async def _run_bash_hook( + command: str, + cwd: str, + timeout_seconds: int, + *, + env: Mapping[str, str] | None = None, + stdin_payload: str | None = None, +) -> HookCommandResult: + process = await asyncio.create_subprocess_shell( + command, + cwd=cwd, + env=dict(env or os.environ.copy()), + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdin_bytes = stdin_payload.encode("utf-8") if stdin_payload is not None else None + try: + stdout_raw, stderr_raw = await asyncio.wait_for( + process.communicate(stdin_bytes), timeout=float(timeout_seconds) + ) + return HookCommandResult( + return_code=int(process.returncode or 0), + stdout=stdout_raw.decode("utf-8", errors="replace"), + stderr=stderr_raw.decode("utf-8", errors="replace"), + timed_out=False, + ) + except TimeoutError: + process.kill() + stdout_raw, stderr_raw = await process.communicate() + return HookCommandResult( + return_code=-1, + stdout=stdout_raw.decode("utf-8", errors="replace"), + stderr=stderr_raw.decode("utf-8", errors="replace"), + timed_out=True, + ) + + +class HookRuntime: + """Loads and executes hooks from project .agents/hooks and user ~/.env/hooks.""" + + def __init__(self, cwd: str) -> None: + self.cwd = str(Path(cwd).resolve()) + self._hooks_by_event, self._load_errors = _load_hooks(self.cwd) + self._load_errors_reported = False + + def reload(self) -> None: + """Reload hook files from disk.""" + self._hooks_by_event, self._load_errors = _load_hooks(self.cwd) + self._load_errors_reported = False + + def hooks_for_event(self, event: HookEventName) -> list[HookDefinition]: + return list(self._hooks_by_event.get(event, [])) + + async def run( + self, + event: HookEventName, + *, + target: str, + variables: Mapping[str, Any] | None = None, + cwd: str | None = None, + dev_mode: bool = False, + allow_prompt_append: bool = True, + ) -> HookExecutionOutcome: + outcome = HookExecutionOutcome() + hook_list = self._hooks_by_event.get(event, []) + run_cwd = str(Path(cwd or self.cwd).resolve()) + context: dict[str, Any] = { + "event": event, + "target": target, + "cwd": run_cwd, + } + if variables: + context.update(variables) + context.setdefault("session_id", "") + context.setdefault("hook_event_name", _claude_event_name(event)) + context.setdefault("rte_hook_event_name", event) + + hook_env = _build_hook_env(run_cwd, context) + hook_stdin_json = _build_hook_stdin_payload(event, target, context) + + if dev_mode and self._load_errors and not self._load_errors_reported: + outcome.debug_lines.extend(self._load_errors) + self._load_errors_reported = True + + if dev_mode: + outcome.debug_lines.append( + f"[dev] hooks {event}: target={target!r}, loaded={len(hook_list)}" + ) + + matched = 0 + for hook in hook_list: + if not _match_target(hook.match, target): + if dev_mode: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: skip match={hook.match!r}" + ) + continue + + matched += 1 + if dev_mode: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: run action={hook.action}" + ) + + if hook.action == "prompt_append": + if not allow_prompt_append: + if dev_mode: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: prompt_append ignored" + ) + if hook.on_failure == "abort": + outcome.aborted = True + outcome.abort_reason = ( + f"Hook {event}/{hook.name} aborted: prompt append is not allowed" + ) + break + continue + + rendered = _render_template(hook.template, context).strip() + if rendered: + outcome.prompt_appends.append(rendered) + if dev_mode: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: appended {len(rendered)} chars" + ) + elif dev_mode: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: rendered empty prompt" + ) + continue + + assert hook.command is not None + command = _render_template(hook.command, context).strip() + if not command: + error_text = f"Hook {event}/{hook.name} failed: rendered command is empty" + if dev_mode: + outcome.debug_lines.append(f"[dev] {error_text}") + if hook.on_failure == "abort": + outcome.aborted = True + outcome.abort_reason = error_text + break + continue + + command_result = await _run_bash_hook( + command, + run_cwd, + hook.timeout_seconds, + env=hook_env, + stdin_payload=hook_stdin_json if hook.pass_stdin_json else None, + ) + if command_result.timed_out: + error_text = ( + f"Hook {event}/{hook.name} timed out after {hook.timeout_seconds}s: {command}" + ) + if dev_mode: + outcome.debug_lines.append(f"[dev] {error_text}") + if hook.on_failure == "abort": + outcome.aborted = True + outcome.abort_reason = error_text + break + continue + + if command_result.return_code != 0: + error_text = ( + f"Hook {event}/{hook.name} failed " + f"(exit {command_result.return_code}): {command}" + ) + if dev_mode: + stderr_preview = _truncate_debug(command_result.stderr) + stdout_preview = _truncate_debug(command_result.stdout) + if stderr_preview: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: stderr={stderr_preview}" + ) + if stdout_preview: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: stdout={stdout_preview}" + ) + outcome.debug_lines.append(f"[dev] {error_text}") + if hook.on_failure == "abort": + outcome.aborted = True + outcome.abort_reason = error_text + break + continue + + if dev_mode: + stdout_preview = _truncate_debug(command_result.stdout) + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: bash ok exit=0" + ) + if stdout_preview: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: stdout={stdout_preview}" + ) + + if hook.parse_stdout_decision: + blocked, reason, parse_error = _parse_hook_decision(command_result.stdout) + if parse_error and dev_mode: + outcome.debug_lines.append( + f"[dev] hooks {event}/{hook.name}: invalid decision JSON ({parse_error})" + ) + if blocked: + outcome.aborted = True + decision_reason = reason or "no reason provided" + outcome.abort_reason = ( + f"Hook {event}/{hook.name} blocked execution: {decision_reason}" + ) + break + + if dev_mode: + outcome.debug_lines.append(f"[dev] hooks {event}: matched={matched}") + + return outcome diff --git a/eagent/mcp/__init__.py b/eagent/mcp/__init__.py new file mode 100644 index 0000000..6f9c24c --- /dev/null +++ b/eagent/mcp/__init__.py @@ -0,0 +1,26 @@ +"""MCP subsystem.""" + +from eagent.mcp.client import McpClient, create_mcp_client +from eagent.mcp.config import get_mcp_config_paths, load_mcp_config, resolve_mcp_command +from eagent.mcp.manager import ( + get_active_mcp_server_count, + get_active_mcp_server_names, + initialize_mcp_servers, + shutdown_mcp_servers, +) +from eagent.mcp.types import McpServerConfig, McpToolCallResult, McpToolDefinition + +__all__ = [ + "McpClient", + "create_mcp_client", + "McpServerConfig", + "McpToolDefinition", + "McpToolCallResult", + "load_mcp_config", + "resolve_mcp_command", + "get_mcp_config_paths", + "initialize_mcp_servers", + "shutdown_mcp_servers", + "get_active_mcp_server_count", + "get_active_mcp_server_names", +] diff --git a/eagent/mcp/client.py b/eagent/mcp/client.py new file mode 100644 index 0000000..88f154d --- /dev/null +++ b/eagent/mcp/client.py @@ -0,0 +1,283 @@ +"""MCP stdio JSON-RPC client.""" + +from __future__ import annotations + +import asyncio +import json +import os +from contextlib import suppress +from typing import Any + +from eagent.mcp.types import ( + JsonRpcError, + JsonRpcResponse, + McpServerConfig, + McpToolDefinition, +) + +MCP_PROTOCOL_VERSION = "2024-11-05" +MAX_RECONNECT_RETRIES = 3 +REQUEST_TIMEOUT_SECONDS = 60 +INIT_TIMEOUT_SECONDS = 30 + + +class McpClient: + def __init__( + self, + server_name: str, + command: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + ) -> None: + self.server_name = server_name + self.command = command + self.args = args or [] + self.env = env or {} + + self.process: asyncio.subprocess.Process | None = None + self._reader_task: asyncio.Task[None] | None = None + self._stderr_task: asyncio.Task[None] | None = None + self._next_id = 1 + self._connected = False + self._reconnect_count = 0 + self._pending: dict[int, asyncio.Future[JsonRpcResponse]] = {} + + @property + def is_connected(self) -> bool: + return self._connected and self.process is not None and self.process.returncode is None + + async def connect(self) -> None: + if self.is_connected: + return + await self._spawn_process() + await self._initialize() + self._connected = True + self._reconnect_count = 0 + + async def list_tools(self) -> list[McpToolDefinition]: + await self._ensure_connected() + response = await self._send_request("tools/list", {}) + if response.error: + raise RuntimeError( + f'MCP tools/list error from "{self.server_name}": {response.error.message}' + ) + result = response.result if isinstance(response.result, dict) else {} + tools_raw = result.get("tools", []) if isinstance(result, dict) else [] + tools: list[McpToolDefinition] = [] + if isinstance(tools_raw, list): + for tool in tools_raw: + if not isinstance(tool, dict): + continue + tools.append( + McpToolDefinition( + name=str(tool.get("name", "")), + description=( + str(tool.get("description")) + if tool.get("description") is not None + else None + ), + inputSchema=( + tool.get("inputSchema") + if isinstance(tool.get("inputSchema"), dict) + else None + ), + ) + ) + return tools + + async def call_tool(self, name: str, args: dict[str, Any]) -> dict[str, Any]: + await self._ensure_connected() + response = await self._send_request("tools/call", {"name": name, "arguments": args}) + if response.error: + return { + "content": [{"type": "text", "text": response.error.message}], + "isError": True, + } + result = response.result if isinstance(response.result, dict) else None + if not result: + return {"content": [{"type": "text", "text": "(empty result)"}], "isError": False} + return result + + async def disconnect(self) -> None: + self._connected = False + + for req_id, future in list(self._pending.items()): + if not future.done(): + future.set_exception(RuntimeError("MCP client disconnecting")) + self._pending.pop(req_id, None) + + if self._reader_task: + self._reader_task.cancel() + with suppress(asyncio.CancelledError): + await self._reader_task + self._reader_task = None + + if self._stderr_task: + self._stderr_task.cancel() + with suppress(asyncio.CancelledError): + await self._stderr_task + self._stderr_task = None + + if self.process: + if self.process.stdin is not None: + self.process.stdin.close() + if self.process.returncode is None: + self.process.terminate() + try: + await asyncio.wait_for(self.process.wait(), timeout=2) + except TimeoutError: + self.process.kill() + await self.process.wait() + self.process = None + + async def _spawn_process(self) -> None: + env = {**os.environ, **self.env} + self.process = await asyncio.create_subprocess_exec( + self.command, + *self.args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + self._reader_task = asyncio.create_task(self._reader_loop()) + self._stderr_task = asyncio.create_task(self._stderr_loop()) + + async def _stderr_loop(self) -> None: + if self.process is None or self.process.stderr is None: + return + try: + while True: + line = await self.process.stderr.readline() + if not line: + return + text = line.decode("utf-8", errors="replace").rstrip() + if text: + print(f"[mcp:{self.server_name}] {text}") + except asyncio.CancelledError: + return + + async def _reader_loop(self) -> None: + if self.process is None or self.process.stdout is None: + return + + while True: + line = await self.process.stdout.readline() + if not line: + await self._handle_process_exit() + return + + payload = line.decode("utf-8", errors="replace").strip() + if not payload: + continue + + try: + message = json.loads(payload) + except json.JSONDecodeError: + continue + + msg_id = message.get("id") + if isinstance(msg_id, int) and msg_id in self._pending: + future = self._pending.pop(msg_id) + if future.done(): + continue + + error = message.get("error") + response = JsonRpcResponse( + jsonrpc=str(message.get("jsonrpc", "2.0")), + id=msg_id, + result=message.get("result"), + error=( + JsonRpcError(**error) + if isinstance(error, dict) and "message" in error + else None + ), + ) + future.set_result(response) + + async def _initialize(self) -> None: + response = await self._send_request( + "initialize", + { + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {}, + "clientInfo": {"name": "RTE-AI", "version": "0.1.0"}, + }, + timeout=INIT_TIMEOUT_SECONDS, + ) + if response.error: + raise RuntimeError( + f'MCP initialize failed for "{self.server_name}": {response.error.message}' + ) + await self._send_notification("notifications/initialized", {}) + + async def _send_request( + self, + method: str, + params: dict[str, Any], + timeout: int = REQUEST_TIMEOUT_SECONDS, + ) -> JsonRpcResponse: + if self.process is None or self.process.stdin is None: + raise RuntimeError(f'MCP server "{self.server_name}" stdin not available') + + req_id = self._next_id + self._next_id += 1 + + request = { + "jsonrpc": "2.0", + "id": req_id, + "method": method, + "params": params, + } + + loop = asyncio.get_running_loop() + future: asyncio.Future[JsonRpcResponse] = loop.create_future() + self._pending[req_id] = future + + self.process.stdin.write((json.dumps(request, ensure_ascii=False) + "\n").encode("utf-8")) + await self.process.stdin.drain() + + try: + return await asyncio.wait_for(future, timeout=timeout) + except TimeoutError as exc: + self._pending.pop(req_id, None) + raise RuntimeError( + f'MCP request "{method}" to "{self.server_name}" timed out after {timeout}s' + ) from exc + + async def _send_notification(self, method: str, params: dict[str, Any]) -> None: + if self.process is None or self.process.stdin is None: + return + request = { + "jsonrpc": "2.0", + "method": method, + "params": params, + } + self.process.stdin.write((json.dumps(request, ensure_ascii=False) + "\n").encode("utf-8")) + await self.process.stdin.drain() + + async def _ensure_connected(self) -> None: + if self.is_connected: + return + await self._attempt_reconnect() + + async def _attempt_reconnect(self) -> None: + if self._reconnect_count >= MAX_RECONNECT_RETRIES: + raise RuntimeError(f'MCP server "{self.server_name}" reconnect retries exhausted') + self._reconnect_count += 1 + await self.disconnect() + await self.connect() + + async def _handle_process_exit(self) -> None: + self._connected = False + if self.process is not None: + await self.process.wait() + + for req_id, future in list(self._pending.items()): + if not future.done(): + future.set_exception(RuntimeError(f'MCP server "{self.server_name}" exited')) + self._pending.pop(req_id, None) + + +def create_mcp_client(server_name: str, config: McpServerConfig) -> McpClient: + return McpClient(server_name, config.command, config.args, config.env) diff --git a/eagent/mcp/config.py b/eagent/mcp/config.py new file mode 100644 index 0000000..b41bba2 --- /dev/null +++ b/eagent/mcp/config.py @@ -0,0 +1,75 @@ +"""MCP configuration loading from settings files.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from eagent.mcp.types import McpServerConfig +from eagent.paths import env_root + +PROJECT_CONFIG_DIR = ".agents" +SETTINGS_FILE = "settings.json" + + +def _read_settings(path: Path) -> dict[str, Any] | None: + try: + return json.loads(path.read_text(encoding="utf-8")) + except Exception: + return None + + +def _extract_servers(settings: dict[str, Any] | None) -> dict[str, McpServerConfig]: + if not settings: + return {} + mcp = settings.get("mcpServers") + if not isinstance(mcp, dict): + return {} + + out: dict[str, McpServerConfig] = {} + for name, cfg in mcp.items(): + if not isinstance(cfg, dict): + continue + command = cfg.get("command") + if not isinstance(command, str) or not command.strip(): + continue + args = cfg.get("args") + env = cfg.get("env") + out[name] = McpServerConfig( + command=command.strip(), + args=[str(a) for a in args] if isinstance(args, list) else [], + env={str(k): str(v) for k, v in env.items()} if isinstance(env, dict) else None, + ) + return out + + +def resolve_mcp_command(config: McpServerConfig) -> McpServerConfig: + command = config.command.strip() + if command.startswith("~"): + command = str(Path(command).expanduser()) + return McpServerConfig(command=command, args=list(config.args), env=dict(config.env or {})) + + +async def load_mcp_config(cwd: str) -> dict[str, McpServerConfig]: + root = Path(cwd).resolve() + + user_servers: dict[str, McpServerConfig] = {} + project_servers: dict[str, McpServerConfig] = {} + + user_settings = _read_settings(env_root() / SETTINGS_FILE) + user_servers.update(_extract_servers(user_settings)) + + project_settings = _read_settings(root / PROJECT_CONFIG_DIR / SETTINGS_FILE) + project_servers.update(_extract_servers(project_settings)) + + merged = {**user_servers, **project_servers} + return {name: resolve_mcp_command(cfg) for name, cfg in merged.items()} + + +def get_mcp_config_paths(cwd: str) -> dict[str, str]: + root = Path(cwd).resolve() + return { + "project": str(root / PROJECT_CONFIG_DIR / SETTINGS_FILE), + "user": str(env_root() / SETTINGS_FILE), + } diff --git a/eagent/mcp/manager.py b/eagent/mcp/manager.py new file mode 100644 index 0000000..de2f120 --- /dev/null +++ b/eagent/mcp/manager.py @@ -0,0 +1,68 @@ +"""MCP manager for initializing servers and wrapping tools.""" + +from __future__ import annotations + +import asyncio + +from eagent.core.types import Tool +from eagent.mcp.client import McpClient, create_mcp_client +from eagent.mcp.config import load_mcp_config +from eagent.tools.mcp_wrapper import wrap_mcp_tool + +_active_clients: list[McpClient] = [] + + +async def initialize_mcp_servers(cwd: str) -> list[Tool]: + configs = await load_mcp_config(cwd) + if not configs: + return [] + + async def _connect(name: str, config) -> tuple[str, list[Tool]]: + client = create_mcp_client(name, config) + try: + await client.connect() + _active_clients.append(client) + tool_defs = await client.list_tools() + wrapped = [ + wrap_mcp_tool( + server_name=name, + tool_name=tool_def.name, + description=tool_def.description or f"MCP tool: {tool_def.name}", + input_schema=tool_def.inputSchema or {"type": "object"}, + client=client, + ) + for tool_def in tool_defs + ] + return name, wrapped + except Exception as exc: + print(f"[mcp] Failed to connect to server '{name}': {exc}") + try: + await client.disconnect() + except Exception: + pass + return name, [] + + tasks = [_connect(name, cfg) for name, cfg in configs.items()] + results = await asyncio.gather(*tasks) + + tools: list[Tool] = [] + for name, server_tools in results: + if server_tools: + print(f"[mcp] Server '{name}': {len(server_tools)} tool(s) registered") + tools.extend(server_tools) + return tools + + +async def shutdown_mcp_servers() -> None: + await asyncio.gather( + *(client.disconnect() for client in list(_active_clients)), return_exceptions=True + ) + _active_clients.clear() + + +def get_active_mcp_server_count() -> int: + return sum(1 for client in _active_clients if client.is_connected) + + +def get_active_mcp_server_names() -> list[str]: + return [client.server_name for client in _active_clients if client.is_connected] diff --git a/eagent/mcp/types.py b/eagent/mcp/types.py new file mode 100644 index 0000000..e4bb713 --- /dev/null +++ b/eagent/mcp/types.py @@ -0,0 +1,49 @@ +"""Types for MCP subsystem.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class McpServerConfig: + command: str + args: list[str] = field(default_factory=list) + env: dict[str, str] | None = None + + +@dataclass +class JsonRpcRequest: + jsonrpc: str + id: int + method: str + params: dict[str, Any] + + +@dataclass +class JsonRpcError: + code: int + message: str + data: Any | None = None + + +@dataclass +class JsonRpcResponse: + jsonrpc: str + id: int | None = None + result: Any | None = None + error: JsonRpcError | None = None + + +@dataclass +class McpToolDefinition: + name: str + description: str | None = None + inputSchema: dict[str, Any] | None = None + + +@dataclass +class McpToolCallResult: + content: list[dict[str, Any]] + isError: bool = False diff --git a/eagent/paths.py b/eagent/paths.py new file mode 100644 index 0000000..d32b808 --- /dev/null +++ b/eagent/paths.py @@ -0,0 +1,21 @@ +"""Shared EnvAgent filesystem locations.""" + +from __future__ import annotations + +import os +from pathlib import Path + + +def env_root() -> Path: + configured = os.getenv("ENV_ROOT") + if configured: + return Path(configured).expanduser().resolve() + return (Path.home() / ".env").resolve() + + +def agents_root() -> Path: + return (Path.home() / ".agents").resolve() + + +def project_agents_root(cwd: str) -> Path: + return Path(cwd).expanduser().resolve() / ".agents" diff --git a/eagent/permissions/__init__.py b/eagent/permissions/__init__.py new file mode 100644 index 0000000..4c8b0b8 --- /dev/null +++ b/eagent/permissions/__init__.py @@ -0,0 +1,30 @@ +"""Permission utilities.""" + +from eagent.permissions.engine import ( + PermissionContext, + add_session_rule, + check_permission, + clear_session_rules, + get_session_rules, + is_read_only_command, +) +from eagent.permissions.modes import ModeRestrictions, get_mode_description, get_mode_restrictions +from eagent.permissions.path_validation import PathValidationResult, validate_path +from eagent.permissions.rules import load_project_rules, load_user_rules, match_rule + +__all__ = [ + "PermissionContext", + "add_session_rule", + "get_session_rules", + "clear_session_rules", + "is_read_only_command", + "check_permission", + "ModeRestrictions", + "get_mode_description", + "get_mode_restrictions", + "PathValidationResult", + "validate_path", + "load_project_rules", + "load_user_rules", + "match_rule", +] diff --git a/eagent/permissions/engine.py b/eagent/permissions/engine.py new file mode 100644 index 0000000..552ca34 --- /dev/null +++ b/eagent/permissions/engine.py @@ -0,0 +1,194 @@ +"""Central permission decision engine.""" + +from __future__ import annotations + +import shlex +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from eagent.core.types import PermissionDecision, PermissionRule +from eagent.permissions.modes import get_mode_restrictions +from eagent.permissions.rules import load_project_rules, load_user_rules, match_rule + +_session_rules: list[PermissionRule] = [] + +READ_ONLY_PREFIXES = ( + "ls", + "cat", + "head", + "tail", + "wc", + "grep", + "rg", + "find", + "fd", + "pwd", + "echo", + "date", + "git log", + "git status", + "git diff", + "git show", +) + + +def add_session_rule(rule: PermissionRule) -> None: + _session_rules.append(rule) + + +def get_session_rules() -> list[PermissionRule]: + return list(_session_rules) + + +def clear_session_rules() -> None: + _session_rules.clear() + + +def is_read_only_command(command: str) -> bool: + cmd = command.strip() + return any( + cmd == prefix or cmd.startswith(prefix + " ") or cmd.startswith(prefix + "\t") + for prefix in READ_ONLY_PREFIXES + ) + + +@dataclass(frozen=True) +class PermissionContext: + cwd: str + permission_mode: str + tools: list[Any] + + +def _resolve_path(raw_path: str, cwd: str) -> Path: + path = Path(raw_path).expanduser() + if not path.is_absolute(): + path = Path(cwd) / path + return path.resolve(strict=False) + + +def _is_within(root: Path, path: Path) -> bool: + try: + path.relative_to(root) + return True + except ValueError: + return False + + +def _collect_candidate_paths(input_data: dict[str, Any], cwd: str) -> list[Path]: + path_keys = ("file_path", "path", "filePath", "filename", "directory") + list_keys = ("paths", "files") + paths: list[Path] = [] + + for key in path_keys: + value = input_data.get(key) + if isinstance(value, str) and value.strip(): + paths.append(_resolve_path(value, cwd)) + + for key in list_keys: + value = input_data.get(key) + if not isinstance(value, list): + continue + for item in value: + if isinstance(item, str) and item.strip(): + paths.append(_resolve_path(item, cwd)) + + return paths + + +def _collect_bash_command_paths(input_data: dict[str, Any], cwd: str) -> list[Path]: + command = input_data.get("command") or input_data.get("cmd") + if not isinstance(command, str) or not command.strip(): + return [] + + try: + tokens = shlex.split(command, posix=True) + except ValueError: + return [] + + candidates: list[Path] = [] + for token in tokens: + trimmed = token.strip() + if not trimmed: + continue + cleaned = trimmed.lstrip("><").rstrip(",;") + if not cleaned: + continue + if cleaned.startswith("-"): + continue + if cleaned.startswith(("http://", "https://")): + continue + if cleaned.startswith(("/", "~/", "./", "../")) or "/" in cleaned: + candidates.append(_resolve_path(cleaned, cwd)) + + return candidates + + +async def check_permission( + tool_name: str, input_data: dict[str, Any], context: PermissionContext +) -> PermissionDecision: + project_rules = await load_project_rules(context.cwd) + user_rules = await load_user_rules() + all_rules = [*_session_rules, *project_rules, *user_rules] + + for rule in all_rules: + if rule.behavior == "deny" and match_rule(tool_name, input_data, rule): + return PermissionDecision( + behavior="deny", message=f"Denied by {rule.source} rule: {rule.tool}" + ) + + if context.permission_mode == "bypassPermissions": + return PermissionDecision(behavior="allow") + + for rule in all_rules: + if rule.behavior == "allow" and match_rule(tool_name, input_data, rule): + return PermissionDecision(behavior="allow") + + tool_def = next((t for t in context.tools if getattr(t, "name", None) == tool_name), None) + restrictions = get_mode_restrictions(context.permission_mode) # type: ignore[arg-type] + is_read_only_tool = False + if tool_def is not None: + try: + is_read_only_tool = bool(tool_def.is_read_only(input_data)) + except Exception: + is_read_only_tool = False + + if context.permission_mode == "plan": + if restrictions.allow_writes or is_read_only_tool: + return PermissionDecision(behavior="allow") + return PermissionDecision( + behavior="deny", message="Write operations are not allowed in plan mode." + ) + + cwd_root = Path(context.cwd).expanduser().resolve(strict=False) + candidate_paths = _collect_candidate_paths(input_data, context.cwd) + if tool_name.lower() == "bash": + candidate_paths.extend(_collect_bash_command_paths(input_data, context.cwd)) + + if candidate_paths: + outside_paths = [path for path in candidate_paths if not _is_within(cwd_root, path)] + if not outside_paths: + return PermissionDecision(behavior="allow") + outside_path = str(outside_paths[0]) + return PermissionDecision( + behavior="ask", + message=( + f'Path "{outside_path}" is outside current directory "{cwd_root}". ' + "Please choose Allow or Deny for this request." + ), + ) + + if tool_name.lower() == "bash": + return PermissionDecision(behavior="allow") + + if ( + context.permission_mode == "acceptEdits" + and restrictions.allow_writes + and tool_name in {"Edit", "Write", "NotebookEdit"} + ): + return PermissionDecision(behavior="allow") + + if is_read_only_tool: + return PermissionDecision(behavior="allow") + + return PermissionDecision(behavior="ask", message=f'Tool "{tool_name}" requires permission.') diff --git a/eagent/permissions/modes.py b/eagent/permissions/modes.py new file mode 100644 index 0000000..8f9e065 --- /dev/null +++ b/eagent/permissions/modes.py @@ -0,0 +1,37 @@ +"""Permission mode metadata.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from eagent.core.types import PermissionMode + +MODE_DESCRIPTIONS: dict[PermissionMode, str] = { + "default": "Default mode: write operations and mutating shell commands require confirmation.", + "plan": "Plan mode: read-only mode with no write operations.", + "acceptEdits": "Accept-edits mode: file edits are allowed, shell mutations still require confirmation.", + "bypassPermissions": "Bypass mode: all actions are auto-approved.", +} + + +@dataclass(frozen=True) +class ModeRestrictions: + allow_reads: bool + allow_writes: bool + allow_bash: bool + + +MODE_RESTRICTIONS: dict[PermissionMode, ModeRestrictions] = { + "default": ModeRestrictions(allow_reads=True, allow_writes=False, allow_bash=False), + "plan": ModeRestrictions(allow_reads=True, allow_writes=False, allow_bash=False), + "acceptEdits": ModeRestrictions(allow_reads=True, allow_writes=True, allow_bash=False), + "bypassPermissions": ModeRestrictions(allow_reads=True, allow_writes=True, allow_bash=True), +} + + +def get_mode_description(mode: PermissionMode) -> str: + return MODE_DESCRIPTIONS.get(mode, f"Unknown mode: {mode}") + + +def get_mode_restrictions(mode: PermissionMode) -> ModeRestrictions: + return MODE_RESTRICTIONS.get(mode, MODE_RESTRICTIONS["default"]) diff --git a/eagent/permissions/path_validation.py b/eagent/permissions/path_validation.py new file mode 100644 index 0000000..e4ed9af --- /dev/null +++ b/eagent/permissions/path_validation.py @@ -0,0 +1,82 @@ +"""Path safety validation.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + +DANGEROUS_ABS = ( + "/", + "/etc", + "/usr", + "/bin", + "/sbin", + "/var", + "/boot", + "/dev", + "/proc", + "/sys", +) + + +def _dangerous_home_paths() -> tuple[str, ...]: + home = str(Path.home()) + return ( + os.path.join(home, ".ssh"), + os.path.join(home, ".aws"), + os.path.join(home, ".gnupg"), + os.path.join(home, ".config"), + ) + + +def _dangerous_match(path: str) -> str | None: + for prefix in DANGEROUS_ABS: + if path == prefix or (prefix != "/" and path.startswith(prefix + os.sep)): + return prefix + for prefix in _dangerous_home_paths(): + if path == prefix or path.startswith(prefix + os.sep): + return prefix + return None + + +def _resolve_symlink(path: Path) -> Path: + try: + return path.resolve(strict=True) + except Exception: + try: + return path.parent.resolve(strict=True) / path.name + except Exception: + return path.resolve(strict=False) + + +@dataclass(frozen=True) +class PathValidationResult: + allowed: bool + message: str | None = None + + +async def validate_path(file_path: str, project_root: str) -> PathValidationResult: + resolved = str(Path(file_path).expanduser().resolve(strict=False)) + danger = _dangerous_match(resolved) + if danger: + return PathValidationResult( + False, f'Access denied: "{resolved}" is within dangerous path "{danger}".' + ) + + real = str(_resolve_symlink(Path(resolved))) + if real != resolved: + danger2 = _dangerous_match(real) + if danger2: + return PathValidationResult( + False, + f'Access denied: "{resolved}" resolves to dangerous path "{real}" under "{danger2}".', + ) + + root = str(Path(project_root).expanduser().resolve(strict=False)) + if real != root and not real.startswith(root + os.sep): + return PathValidationResult( + True, f'Warning: "{resolved}" is outside project root "{root}".' + ) + + return PathValidationResult(True) diff --git a/eagent/permissions/rules.py b/eagent/permissions/rules.py new file mode 100644 index 0000000..7dc4790 --- /dev/null +++ b/eagent/permissions/rules.py @@ -0,0 +1,129 @@ +"""Permission rule loading and matching.""" + +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Any + +from eagent.core.types import PermissionRule +from eagent.paths import env_root + +PROJECT_CONFIG_DIR = ".agents" + + +def _read_settings(path: Path) -> dict[str, Any] | None: + try: + return json.loads(path.read_text(encoding="utf-8")) + except Exception: + return None + + +def _extract_rules(settings: dict[str, Any] | None, source: str) -> list[PermissionRule]: + if not settings: + return [] + permissions = settings.get("permissions") + if not isinstance(permissions, dict): + return [] + + out: list[PermissionRule] = [] + for behavior in ("allow", "deny"): + entries = permissions.get(behavior) + if not isinstance(entries, list): + continue + for entry in entries: + if not isinstance(entry, dict): + continue + tool = entry.get("tool") + if not isinstance(tool, str): + continue + content = entry.get("content") + out.append( + PermissionRule( + tool=tool, + behavior=behavior, # type: ignore[arg-type] + source=source, # type: ignore[arg-type] + content=content if isinstance(content, str) else None, + ) + ) + return out + + +async def load_project_rules(cwd: str) -> list[PermissionRule]: + root = Path(cwd).resolve() + rules: list[PermissionRule] = [] + rules.extend(_extract_rules(_read_settings(root / PROJECT_CONFIG_DIR / "settings.json"), "project")) + return rules + + +async def load_user_rules() -> list[PermissionRule]: + rules: list[PermissionRule] = [] + rules.extend(_extract_rules(_read_settings(env_root() / "settings.json"), "user")) + return rules + + +def _glob_to_regex(pattern: str) -> re.Pattern[str]: + esc = "" + i = 0 + while i < len(pattern): + ch = pattern[i] + if ch == "*" and i + 1 < len(pattern) and pattern[i + 1] == "*": + esc += ".*" + i += 2 + elif ch == "*": + esc += "[^/]*" + i += 1 + elif ch == "?": + esc += "[^/]" + i += 1 + else: + esc += re.escape(ch) + i += 1 + return re.compile(f"^{esc}$") + + +def _glob_match(pattern: str, value: str) -> bool: + try: + return bool(_glob_to_regex(pattern).match(value)) + except re.error: + return False + + +def _match_tool(tool_name: str, rule_pattern: str) -> bool: + if tool_name == rule_pattern: + return True + if rule_pattern.startswith("mcp__") and rule_pattern.endswith("__"): + return tool_name.startswith(rule_pattern) + if "*" in rule_pattern or "?" in rule_pattern: + return _glob_match(rule_pattern, tool_name) + return False + + +def _match_content(tool_name: str, input_data: dict[str, Any], pattern: str) -> bool: + name = tool_name.lower() + if name == "bash": + command = input_data.get("command") or input_data.get("cmd") or "" + return isinstance(command, str) and _glob_match(pattern, command) + + file_path = ( + input_data.get("file_path") + or input_data.get("path") + or input_data.get("filePath") + or input_data.get("filename") + ) + if isinstance(file_path, str) and file_path: + return _glob_match(pattern, file_path) + + for value in input_data.values(): + if isinstance(value, str) and _glob_match(pattern, value): + return True + return False + + +def match_rule(tool_name: str, input_data: dict[str, Any], rule: PermissionRule) -> bool: + if not _match_tool(tool_name, rule.tool): + return False + if not rule.content: + return True + return _match_content(tool_name, input_data, rule.content) diff --git a/eagent/prompt/__init__.py b/eagent/prompt/__init__.py new file mode 100644 index 0000000..5b96f10 --- /dev/null +++ b/eagent/prompt/__init__.py @@ -0,0 +1 @@ +"""Module package.""" diff --git a/eagent/prompt/agent_prompts.py b/eagent/prompt/agent_prompts.py new file mode 100644 index 0000000..20d02a9 --- /dev/null +++ b/eagent/prompt/agent_prompts.py @@ -0,0 +1,17 @@ +"""Sub-agent prompts.""" + +DEFAULT_AGENT_PROMPT = "You are a sub-agent. Complete the delegated task thoroughly." +EXPLORE_AGENT_PROMPT = "You are a read-only exploration sub-agent." +PLAN_AGENT_PROMPT = "You are a planning sub-agent. Do not modify files." + +AGENT_PROMPTS = { + "default": DEFAULT_AGENT_PROMPT, + "explore": EXPLORE_AGENT_PROMPT, + "plan": PLAN_AGENT_PROMPT, +} + + +def get_agent_prompt(agent_type: str | None = None) -> str: + if not agent_type: + return DEFAULT_AGENT_PROMPT + return AGENT_PROMPTS.get(agent_type.lower(), DEFAULT_AGENT_PROMPT) diff --git a/eagent/prompt/cache_boundary.py b/eagent/prompt/cache_boundary.py new file mode 100644 index 0000000..cc5a24a --- /dev/null +++ b/eagent/prompt/cache_boundary.py @@ -0,0 +1,34 @@ +"""Prompt cache boundary helpers.""" + +from __future__ import annotations + +from eagent.core.types import SystemPromptBlock + +SYSTEM_PROMPT_DYNAMIC_BOUNDARY = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__" + + +def split_system_blocks( + blocks: list[SystemPromptBlock], +) -> tuple[list[SystemPromptBlock], list[SystemPromptBlock]]: + index = next((i for i, b in enumerate(blocks) if b.text == SYSTEM_PROMPT_DYNAMIC_BOUNDARY), -1) + if index < 0: + return list(blocks), [] + return blocks[:index], blocks[index + 1 :] + + +def apply_cache(blocks: list[SystemPromptBlock]) -> list[SystemPromptBlock]: + static_blocks, dynamic_blocks = split_system_blocks(blocks) + if not static_blocks: + return dynamic_blocks + + out = [ + SystemPromptBlock(type="text", text=b.text, cache_control=b.cache_control) + for b in static_blocks + ] + out[-1].cache_control = {"type": "ephemeral"} + out.extend(dynamic_blocks) + return out + + +def strip_boundary(blocks: list[SystemPromptBlock]) -> list[SystemPromptBlock]: + return [b for b in blocks if b.text != SYSTEM_PROMPT_DYNAMIC_BOUNDARY] diff --git a/eagent/prompt/compact_prompt.py b/eagent/prompt/compact_prompt.py new file mode 100644 index 0000000..6836826 --- /dev/null +++ b/eagent/prompt/compact_prompt.py @@ -0,0 +1,51 @@ +"""Conversation compaction prompts.""" + +from __future__ import annotations + +COMPACT_BOUNDARY_MARKER = "[CONVERSATION_COMPACTED]" + +COMPACT_SYSTEM_INSTRUCTION = ( + "You are a conversation summarizer. Only summarize the conversation in a structured way." +) + +COMPACT_PROMPT = """Summarize the following conversation in detail. +Use sections: +1) Primary request and intent +2) Technical concepts +3) Files and code sections +4) Errors and fixes +5) Problem solving +6) All user messages +7) Pending tasks +8) Current work +9) Optional next step +Keep identifiers, paths, and commands concrete. +""" + + +def format_compact_summary(summary: str) -> str: + return ( + f"{COMPACT_BOUNDARY_MARKER}\n\n" + "The following is a summary of previous conversation. Continue from it.\n\n" + f"{summary}\n" + ) + + +def serialize_messages_for_compact(messages: list[dict[str, object]]) -> str: + lines: list[str] = [] + for msg in messages: + role = str(msg.get("role", "unknown")).upper() + lines.append(f"--- {role} ---") + content = msg.get("content", []) + if isinstance(content, list): + for block in content: + if isinstance(block, dict): + btype = block.get("type") + if btype == "text": + lines.append(str(block.get("text", ""))) + elif btype == "tool_use": + lines.append(f"[Tool call] {block.get('name')} {block.get('input')}") + elif btype == "tool_result": + lines.append(f"[Tool result] {block.get('content')}") + lines.append("") + return "\n".join(lines) diff --git a/eagent/prompt/system_prompt.py b/eagent/prompt/system_prompt.py new file mode 100644 index 0000000..2426b1e --- /dev/null +++ b/eagent/prompt/system_prompt.py @@ -0,0 +1,55 @@ +"""System prompt builder.""" + +from __future__ import annotations + +import datetime as _dt +import os + +from eagent.core.types import SystemPromptBlock +from eagent.prompt.cache_boundary import SYSTEM_PROMPT_DYNAMIC_BOUNDARY + +_STATIC_PROMPT = """You are RTE-AI, a CLI coding assistant. +Be concise, safe, and action-oriented. +Use available tools to complete tasks. +Prefer dedicated tools over shell commands. +""" + + +def build_system_prompt_blocks( + agent_memory: str, git_context: str, cwd: str, model: str +) -> list[SystemPromptBlock]: + blocks: list[SystemPromptBlock] = [ + SystemPromptBlock(type="text", text=_STATIC_PROMPT.strip()), + SystemPromptBlock(type="text", text=SYSTEM_PROMPT_DYNAMIC_BOUNDARY), + ] + + if agent_memory.strip(): + blocks.append( + SystemPromptBlock( + type="text", + text=f"## Memory\n\n{agent_memory.strip()}\n", + ) + ) + + blocks.append( + SystemPromptBlock( + type="text", + text=( + "## Environment\n" + f"- Working directory: {cwd}\n" + f"- Model: {model}\n" + f"- Platform: {os.name}\n" + f"- Date: {_dt.date.today().isoformat()}" + ), + ) + ) + + if git_context.strip(): + blocks.append( + SystemPromptBlock( + type="text", + text=f"## Git Status\n\n{git_context.strip()}\n", + ) + ) + + return blocks diff --git a/eagent/reload.py b/eagent/reload.py new file mode 100644 index 0000000..50a565d --- /dev/null +++ b/eagent/reload.py @@ -0,0 +1,27 @@ +"""Reload argument handling for Env-integrated and standalone agent entrypoints.""" + +from __future__ import annotations + +import json +import os +import sys + +_RELOAD_ARGV_ENV = "EAGENT_RELOAD_ARGV" + + +class ReloadArgs: + @staticmethod + def remember(argv: list[str]) -> None: + os.environ[_RELOAD_ARGV_ENV] = json.dumps(argv) + + @staticmethod + def current() -> list[str]: + raw = os.environ.get(_RELOAD_ARGV_ENV) + if raw: + try: + value = json.loads(raw) + except Exception: + value = None + if isinstance(value, list) and value and all(isinstance(item, str) for item in value): + return value + return list(sys.argv) diff --git a/eagent/skills/__init__.py b/eagent/skills/__init__.py new file mode 100644 index 0000000..11ae315 --- /dev/null +++ b/eagent/skills/__init__.py @@ -0,0 +1,32 @@ +"""Skills subsystem.""" + +from eagent.skills.loader import ( + load_all_skills, + load_skills_from_dir, + parse_frontmatter, + parse_skill_file, +) +from eagent.skills.skill_tool import ( + build_skill_tool, + format_skill_listing, + get_loaded_skills, + initialize_skills, + reset_skills, + set_skill_query_params, +) +from eagent.skills.types import SkillDefinition, SkillLoadResult + +__all__ = [ + "SkillDefinition", + "SkillLoadResult", + "parse_frontmatter", + "parse_skill_file", + "load_skills_from_dir", + "load_all_skills", + "initialize_skills", + "get_loaded_skills", + "reset_skills", + "set_skill_query_params", + "format_skill_listing", + "build_skill_tool", +] diff --git a/eagent/skills/loader.py b/eagent/skills/loader.py new file mode 100644 index 0000000..2cb36c4 --- /dev/null +++ b/eagent/skills/loader.py @@ -0,0 +1,267 @@ +"""Skill loader and frontmatter parser.""" + +from __future__ import annotations + +import fnmatch +import re +from pathlib import Path +from typing import Any + +from eagent.paths import agents_root, env_root +from eagent.skills.types import SkillDefinition, SkillLoadResult, SkillSource + +SKILL_FILE = "SKILL.md" + + +def _parse_simple_yaml(yaml_text: str) -> dict[str, Any]: + result: dict[str, Any] = {} + lines = yaml_text.splitlines() + current_key: str | None = None + current_array: list[str] | None = None + + for raw in lines: + line = raw.rstrip() + if not line.strip() or line.lstrip().startswith("#"): + continue + + array_match = re.match(r"^\s*-\s+(.+)\s*$", line) + if array_match and current_key: + if current_array is None: + current_array = [] + current_array.append(_unquote(array_match.group(1).strip())) + result[current_key] = current_array + continue + + if current_array is not None and current_key: + result[current_key] = current_array + current_array = None + + kv = re.match(r"^([a-zA-Z0-9_-]+)\s*:\s*(.*)$", line) + if not kv: + continue + + current_key = kv.group(1) + value = kv.group(2).strip() + + if not value: + current_array = [] + continue + + if value.startswith("[") and value.endswith("]"): + inner = value[1:-1] + result[current_key] = [ + _unquote(part.strip()) for part in inner.split(",") if part.strip() + ] + current_array = None + continue + + lower = value.lower() + if lower in {"true", "yes"}: + result[current_key] = True + current_array = None + continue + if lower in {"false", "no"}: + result[current_key] = False + current_array = None + continue + if re.fullmatch(r"-?\d+(\.\d+)?", value): + result[current_key] = float(value) if "." in value else int(value) + current_array = None + continue + + result[current_key] = _unquote(value) + current_array = None + + if current_array is not None and current_key: + result[current_key] = current_array + + return result + + +def _unquote(value: str) -> str: + if (value.startswith('"') and value.endswith('"')) or ( + value.startswith("'") and value.endswith("'") + ): + return value[1:-1] + return value + + +def parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: + stripped = content.lstrip() + if not stripped.startswith("---"): + return {}, content + + rest = stripped[3:] + idx = rest.find("\n---") + if idx < 0: + return {}, content + + yaml = rest[:idx].strip() + body = rest[idx + 4 :].lstrip("\n") + return _parse_simple_yaml(yaml), body + + +def _normalize_string_list(value: Any) -> list[str] | None: + if isinstance(value, str): + return [value] + if isinstance(value, list): + items = [str(v) for v in value if isinstance(v, (str, int, float))] + return items if items else None + return None + + +def _split_args(args: str) -> list[str]: + if not args.strip(): + return [] + + values: list[str] = [] + current: list[str] = [] + in_quote = False + quote_char = "" + i = 0 + while i < len(args): + ch = args[i] + if ch in {"'", '"'}: + if in_quote and ch == quote_char: + in_quote = False + quote_char = "" + i += 1 + continue + if not in_quote: + in_quote = True + quote_char = ch + i += 1 + continue + if not in_quote and ch.isspace(): + if current: + values.append("".join(current)) + current = [] + i += 1 + continue + current.append(ch) + i += 1 + if current: + values.append("".join(current)) + return values + + +def _escape_regex(text: str) -> str: + return re.escape(text) + + +def _make_expander(template: str, skill_root: str, arg_names: list[str] | None): + async def _expand(args: str) -> str: + result = template + + for key in ( + "${ENV_AGENT_SKILL_DIR}", + "$ENV_AGENT_SKILL_DIR", + ): + result = result.replace(key, skill_root) + + result = result.replace("$ARGUMENTS", args) + positional = _split_args(args) + + if arg_names: + for idx, name in enumerate(arg_names): + value = positional[idx] if idx < len(positional) else "" + result = re.sub(rf"\\${_escape_regex(name)}\\b", value, result) + + for i in range(1, 10): + value = positional[i - 1] if i - 1 < len(positional) else "" + result = re.sub(rf"\\${i}\\b", value, result) + + return result + + return _expand + + +def _matches_paths(skill: SkillDefinition, cwd: Path) -> bool: + if not skill.paths: + return True + for pattern in skill.paths: + if any(fnmatch.fnmatch(str(path.relative_to(cwd)), pattern) for path in cwd.rglob("*")): + return True + return False + + +async def parse_skill_file(file_path: str) -> SkillDefinition | None: + path = Path(file_path) + try: + content = path.read_text(encoding="utf-8") + except Exception: + return None + if not content.strip(): + return None + + frontmatter, body = parse_frontmatter(content) + + skill_root = str(path.parent.resolve()) + name = str(frontmatter.get("name") or path.parent.name) + description = str(frontmatter.get("description") or f"Skill: {name}") + argument_names = _normalize_string_list(frontmatter.get("arguments")) + paths = _normalize_string_list(frontmatter.get("paths")) + allowed_tools = _normalize_string_list(frontmatter.get("allowed-tools")) + + skill = SkillDefinition( + name=name, + description=description, + when_to_use=str(frontmatter.get("when_to_use")) if frontmatter.get("when_to_use") else None, + argument_hint=( + str(frontmatter.get("argument-hint")) if frontmatter.get("argument-hint") else None + ), + argument_names=argument_names, + allowed_tools=allowed_tools, + model=str(frontmatter.get("model")) if frontmatter.get("model") else None, + user_invocable=bool(frontmatter.get("user-invocable", True)), + context="fork" if str(frontmatter.get("context") or "inline") == "fork" else "inline", + agent=str(frontmatter.get("agent")) if frontmatter.get("agent") else None, + paths=paths, + skill_root=skill_root, + get_prompt=_make_expander(body, skill_root, argument_names), + ) + + return skill + + +async def load_skills_from_dir(skills_dir: str, source: SkillSource) -> list[SkillLoadResult]: + root = Path(skills_dir) + if not root.exists() or not root.is_dir(): + return [] + + results: list[SkillLoadResult] = [] + for child in sorted(root.iterdir()): + if not child.is_dir(): + continue + skill_file = child / SKILL_FILE + if not skill_file.exists(): + continue + skill = await parse_skill_file(str(skill_file)) + if skill is None: + continue + results.append(SkillLoadResult(skill=skill, source=source, file_path=str(skill_file))) + return results + + +async def load_all_skills(cwd: str) -> list[SkillDefinition]: + cwd_root = Path(cwd).resolve() + seen: dict[str, SkillLoadResult] = {} + + search_dirs: list[tuple[Path, SkillSource]] = [ + (cwd_root / ".agents" / "skills", "project"), + (agents_root() / "skills", "user"), + (env_root() / "skills", "user"), + ] + + for skills_dir, source in search_dirs: + results = await load_skills_from_dir(str(skills_dir), source) + for result in results: + key = result.skill.name.lower() + if key not in seen: + seen[key] = result + + out: list[SkillDefinition] = [] + for item in seen.values(): + if _matches_paths(item.skill, cwd_root): + out.append(item.skill) + return out diff --git a/eagent/skills/skill_tool.py b/eagent/skills/skill_tool.py new file mode 100644 index 0000000..45e6b52 --- /dev/null +++ b/eagent/skills/skill_tool.py @@ -0,0 +1,142 @@ +"""Skill tool integration.""" + +from __future__ import annotations + +from typing import Any + +from eagent.core.types import QueryParams, Tool, ToolContext, ToolResult +from eagent.skills.loader import load_all_skills +from eagent.skills.types import SkillDefinition + +_loaded_skills: list[SkillDefinition] = [] +_initialized = False +_query_params: QueryParams | None = None + + +def set_skill_query_params(params: QueryParams | None) -> None: + global _query_params + _query_params = params + + +async def initialize_skills(cwd: str) -> None: + global _loaded_skills, _initialized + _loaded_skills = await load_all_skills(cwd) + _initialized = True + + +def get_loaded_skills() -> list[SkillDefinition]: + return list(_loaded_skills) + + +def reset_skills() -> None: + global _loaded_skills, _initialized + _loaded_skills = [] + _initialized = False + + +def format_skill_listing(skills: list[SkillDefinition]) -> str: + if not skills: + return "" + lines = ["Available skills (invoke via Skill tool):", ""] + for skill in skills: + lines.append(f" - {skill.name}: {skill.description}") + if skill.when_to_use: + lines.append(f" When to use: {skill.when_to_use}") + if skill.argument_hint: + lines.append(f" Arguments: {skill.argument_hint}") + if skill.context == "fork": + lines.append(" Execution: forked sub-agent") + return "\n".join(lines) + + +def _find_skill(name: str) -> SkillDefinition | None: + normalized = name[1:] if name.startswith("/") else name + lookup = normalized.lower() + for skill in _loaded_skills: + if skill.name.lower() == lookup: + return skill + return None + + +def _skill_not_found(requested: str) -> str: + if not _loaded_skills: + return ( + f'Skill "{requested}" not found. No skills are loaded. ' + "Place skills under .agents/skills, ~/.agents/skills, or ~/.env/skills." + ) + lines = [f'Skill "{requested}" not found. Available skills:'] + for skill in _loaded_skills: + lines.append(f" - {skill.name}: {skill.description}") + return "\n".join(lines) + + +async def _execute_inline(skill: SkillDefinition, args: str) -> ToolResult: + assert skill.get_prompt is not None + prompt = await skill.get_prompt(args) + return ToolResult(result=prompt) + + +async def _execute_fork(skill: SkillDefinition, args: str, context: ToolContext) -> ToolResult: + if _query_params is None: + return await _execute_inline(skill, args) + + from eagent.core.agent_loop import run_sub_agent + + assert skill.get_prompt is not None + prompt = await skill.get_prompt(args) + + try: + result = await run_sub_agent( + prompt, + _query_params, + tools=skill.allowed_tools, + max_turns=50, + model=skill.model, + ) + return ToolResult(result=result or "(Skill produced no output)") + except Exception as exc: + return ToolResult(result=f"Skill execution failed: {exc}", is_error=True) + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + if not _initialized: + await initialize_skills(context.cwd) + + skill_name = str(input_data.get("skill") or "").strip() + args = str(input_data.get("args") or "") + + if not skill_name: + return ToolResult(result="Error: skill name is required.", is_error=True) + + skill = _find_skill(skill_name) + if skill is None: + return ToolResult(result=_skill_not_found(skill_name), is_error=True) + + if skill.context == "fork": + return await _execute_fork(skill, args, context) + return await _execute_inline(skill, args) + + +def build_skill_tool() -> Tool: + return Tool( + name="Skill", + description=( + "Invoke a loaded skill by name. Skills are reusable prompt templates from " + ".agents/skills, ~/.agents/skills, or ~/.env/skills." + ), + input_schema={ + "type": "object", + "properties": { + "skill": {"type": "string"}, + "args": {"type": "string"}, + }, + "required": ["skill"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: format_skill_listing(_loaded_skills), + is_read_only=lambda _i: True, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=120_000, + user_facing_name=lambda input_data: f"Skill: {input_data.get('skill')}", + ) diff --git a/eagent/skills/types.py b/eagent/skills/types.py new file mode 100644 index 0000000..36c76bc --- /dev/null +++ b/eagent/skills/types.py @@ -0,0 +1,34 @@ +"""Types for skills subsystem.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Literal + + +@dataclass +class SkillDefinition: + name: str + description: str + when_to_use: str | None = None + argument_hint: str | None = None + argument_names: list[str] | None = None + allowed_tools: list[str] | None = None + model: str | None = None + user_invocable: bool = True + context: Literal["inline", "fork"] = "inline" + agent: str | None = None + paths: list[str] | None = None + skill_root: str | None = None + get_prompt: Callable[[str], Awaitable[str]] | None = None + + +SkillSource = Literal["user", "project"] + + +@dataclass +class SkillLoadResult: + skill: SkillDefinition + source: SkillSource + file_path: str diff --git a/eagent/tools/__init__.py b/eagent/tools/__init__.py new file mode 100644 index 0000000..74ce4da --- /dev/null +++ b/eagent/tools/__init__.py @@ -0,0 +1,62 @@ +"""Built-in tools and registry.""" + +from eagent.tools.agent_tool import build_agent_tool, set_agent_query_params +from eagent.tools.ask import build_ask_tool +from eagent.tools.bash import build_bash_tool +from eagent.tools.bash_readonly import is_read_only_command, parse_command_parts +from eagent.tools.edit import build_edit_tool +from eagent.tools.glob_tool import build_glob_tool +from eagent.tools.grep_tool import build_grep_tool +from eagent.tools.notebook_edit import build_notebook_edit_tool +from eagent.tools.plan_mode import build_enter_plan_mode_tool, build_exit_plan_mode_tool +from eagent.tools.read import build_read_tool +from eagent.tools.registry import ( + exclude_tools_by_name, + filter_tools_by_name, + generate_tool_summary, + get_all_tools, + get_read_only_tool_names, + get_tool_by_name, + get_tool_count, + has_tool_by_name, + initialize_tools, + register_dynamic_tools, + register_tool, + reset_registry, +) +from eagent.tools.todo import build_todo_tool +from eagent.tools.web_fetch import build_web_fetch_tool +from eagent.tools.web_search import build_web_search_tool +from eagent.tools.write import build_write_tool + +__all__ = [ + "build_bash_tool", + "build_read_tool", + "build_edit_tool", + "build_write_tool", + "build_glob_tool", + "build_grep_tool", + "build_agent_tool", + "set_agent_query_params", + "build_ask_tool", + "build_todo_tool", + "build_web_fetch_tool", + "build_web_search_tool", + "build_enter_plan_mode_tool", + "build_exit_plan_mode_tool", + "build_notebook_edit_tool", + "is_read_only_command", + "parse_command_parts", + "register_tool", + "register_dynamic_tools", + "get_all_tools", + "get_tool_by_name", + "has_tool_by_name", + "get_tool_count", + "initialize_tools", + "reset_registry", + "get_read_only_tool_names", + "filter_tools_by_name", + "exclude_tools_by_name", + "generate_tool_summary", +] diff --git a/eagent/tools/agent_tool.py b/eagent/tools/agent_tool.py new file mode 100644 index 0000000..36fb8b3 --- /dev/null +++ b/eagent/tools/agent_tool.py @@ -0,0 +1,67 @@ +"""Agent delegation tool.""" + +from __future__ import annotations + +from typing import Any + +from eagent.core.types import QueryParams, Tool, ToolContext, ToolResult + +_AGENT_PARAMS: QueryParams | None = None + + +def set_agent_query_params(params: QueryParams | None) -> None: + global _AGENT_PARAMS + _AGENT_PARAMS = params + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + prompt = str(input_data.get("prompt") or "").strip() + if not prompt: + return ToolResult(result="Error: prompt is required.", is_error=True) + if _AGENT_PARAMS is None: + return ToolResult(result="Error: agent context not initialized.", is_error=True) + + tools = input_data.get("tools") + disallowed_tools = input_data.get("disallowed_tools") + max_turns = input_data.get("max_turns") + model = input_data.get("model") + + from eagent.core.agent_loop import run_sub_agent + + try: + text = await run_sub_agent( + prompt, + _AGENT_PARAMS, + tools=tools if isinstance(tools, list) else None, + disallowed_tools=disallowed_tools if isinstance(disallowed_tools, list) else None, + max_turns=int(max_turns) if isinstance(max_turns, int) else None, + model=str(model) if isinstance(model, str) else None, + ) + return ToolResult(result=text) + except Exception as exc: + return ToolResult(result=f"Sub-agent failed: {exc}", is_error=True) + + +def build_agent_tool() -> Tool: + return Tool( + name="Agent", + description="Run a delegated sub-task in an isolated sub-agent.", + input_schema={ + "type": "object", + "properties": { + "prompt": {"type": "string"}, + "tools": {"type": "array", "items": {"type": "string"}}, + "disallowed_tools": {"type": "array", "items": {"type": "string"}}, + "max_turns": {"type": "integer", "minimum": 1}, + "model": {"type": "string"}, + }, + "required": ["prompt"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Delegate bounded sub-tasks when parallel or isolated reasoning is beneficial.", + is_read_only=lambda _i: False, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=100_000, + user_facing_name=lambda _i: "Agent", + ) diff --git a/eagent/tools/ask.py b/eagent/tools/ask.py new file mode 100644 index 0000000..6709c25 --- /dev/null +++ b/eagent/tools/ask.py @@ -0,0 +1,36 @@ +"""Ask tool for requesting clarification from user.""" + +from __future__ import annotations + +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + _ = context + question = str(input_data.get("question") or "").strip() + if not question: + return ToolResult(result="Error: question is required.", is_error=True) + return ToolResult(result=f"[ASK_USER] {question}") + + +def build_ask_tool() -> Tool: + return Tool( + name="Ask", + description="Request clarification from user.", + input_schema={ + "type": "object", + "properties": { + "question": {"type": "string"}, + }, + "required": ["question"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Use Ask when required input is missing and assumptions are risky.", + is_read_only=lambda _i: True, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=10_000, + user_facing_name=lambda _i: "Ask", + ) diff --git a/eagent/tools/bash.py b/eagent/tools/bash.py new file mode 100644 index 0000000..43f93ad --- /dev/null +++ b/eagent/tools/bash.py @@ -0,0 +1,138 @@ +"""Bash tool.""" + +from __future__ import annotations + +import asyncio +import contextlib +import os +from asyncio.subprocess import PIPE +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult +from eagent.tools.bash_readonly import is_read_only_command + +DEFAULT_TIMEOUT_SECONDS = 120 +MAX_TIMEOUT_SECONDS = 600 +MAX_RESULT_SIZE_CHARS = 30_000 +MAX_BUFFER_SIZE = 10 * 1024 * 1024 + + +async def _execute(command: str, cwd: str, timeout_s: int) -> tuple[str, str, int, bool]: + process = await asyncio.create_subprocess_exec( + "bash", + "-c", + command, + cwd=cwd, + stdout=PIPE, + stderr=PIPE, + env={ + **os.environ, + "LANG": os.environ.get("LANG", "en_US.UTF-8"), + "TERM": os.environ.get("TERM", "xterm-256color"), + "GIT_PAGER": "cat", + "PAGER": "cat", + }, + ) + + timed_out = False + try: + stdout_b, stderr_b = await asyncio.wait_for(process.communicate(), timeout=timeout_s) + except asyncio.CancelledError: + with contextlib.suppress(ProcessLookupError): + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=2) + except TimeoutError: + with contextlib.suppress(ProcessLookupError): + process.kill() + await process.wait() + raise + except TimeoutError: + timed_out = True + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=2) + except TimeoutError: + process.kill() + await process.wait() + stdout_b, stderr_b = b"", b"[command timed out]" + + stdout = stdout_b.decode("utf-8", errors="replace")[:MAX_BUFFER_SIZE] + stderr = stderr_b.decode("utf-8", errors="replace")[:MAX_BUFFER_SIZE] + code = process.returncode or 0 + return stdout, stderr, code, timed_out + + +def _format_output(stdout: str, stderr: str, code: int, timed_out: bool) -> str: + parts: list[str] = [] + if timed_out: + parts.append("[Command timed out]") + + if stdout: + if len(stdout) > MAX_RESULT_SIZE_CHARS: + parts.append(stdout[:MAX_RESULT_SIZE_CHARS]) + parts.append(f"\n[stdout truncated: {len(stdout)} chars total]") + else: + parts.append(stdout) + + if stderr and stderr.strip(): + cap = MAX_RESULT_SIZE_CHARS // 3 + if len(stderr) > cap: + parts.append(f"\nSTDERR:\n{stderr[:cap]}") + parts.append(f"[stderr truncated: {len(stderr)} chars total]") + else: + parts.append(f"\nSTDERR:\n{stderr}") + + if not parts: + return "(No output)" if code == 0 else f"(No output, exit code: {code})" + + if code != 0 and not timed_out: + parts.append(f"\n(exit code: {code})") + return "".join(parts) + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + command = str(input_data.get("command") or "").strip() + if not command: + return ToolResult(result="Error: command cannot be empty.", is_error=True) + + timeout = int(input_data.get("timeout") or DEFAULT_TIMEOUT_SECONDS) + timeout = max(1, min(timeout, MAX_TIMEOUT_SECONDS)) + + try: + stdout, stderr, code, timed_out = await _execute(command, context.cwd, timeout) + result = _format_output(stdout, stderr, code, timed_out) + return ToolResult(result=result, is_error=(code != 0 and not timed_out)) + except Exception as exc: + return ToolResult(result=f"Error executing command: {exc}", is_error=True) + + +def build_bash_tool() -> Tool: + return Tool( + name="Bash", + description=( + "Execute a bash command. Use for running scripts, searching code, " + "checking file status, and tests." + ), + input_schema={ + "type": "object", + "properties": { + "command": {"type": "string"}, + "timeout": {"type": "integer", "minimum": 1, "maximum": MAX_TIMEOUT_SECONDS}, + "description": {"type": "string"}, + }, + "required": ["command"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: ( + "Execute bash commands in the working directory. " + "Prefer non-interactive commands and include timeout for long operations." + ), + is_read_only=lambda input_data: is_read_only_command(str(input_data.get("command") or "")), + is_concurrency_safe=lambda input_data: is_read_only_command( + str(input_data.get("command") or "") + ), + max_result_size_chars=MAX_RESULT_SIZE_CHARS, + user_facing_name=lambda input_data: f"Bash: {str(input_data.get('command') or '')[:60]}", + ) diff --git a/eagent/tools/bash_readonly.py b/eagent/tools/bash_readonly.py new file mode 100644 index 0000000..8d96e31 --- /dev/null +++ b/eagent/tools/bash_readonly.py @@ -0,0 +1,281 @@ +"""Read-only shell command analyzer used by Bash tool and permissions.""" + +from __future__ import annotations + +import re +import shlex + +SAFE_COMMANDS = { + "cat", + "head", + "tail", + "less", + "more", + "wc", + "sort", + "uniq", + "diff", + "comm", + "find", + "ls", + "tree", + "pwd", + "echo", + "printf", + "grep", + "egrep", + "fgrep", + "rg", + "ag", + "ack", + "awk", + "sed", + "tr", + "cut", + "paste", + "column", + "fold", + "fmt", + "expand", + "unexpand", + "tee", + "git", + "which", + "type", + "file", + "stat", + "du", + "df", + "env", + "printenv", + "date", + "uname", + "whoami", + "id", + "hostname", + "test", + "true", + "false", + "npm", + "yarn", + "pnpm", + "pip", + "pip3", + "realpath", + "dirname", + "basename", + "readlink", + "md5sum", + "sha256sum", + "sha1sum", + "shasum", + "xxd", + "od", + "strings", + "nm", + "hexdump", + "jq", + "xargs", + "python", + "python3", + "node", + "go", + "rustc", + "cargo", + "java", + "javac", + "gcc", + "clang", +} + +GIT_SAFE_SUBCOMMANDS = { + "log", + "diff", + "show", + "status", + "branch", + "remote", + "tag", + "rev-parse", + "rev-list", + "describe", + "shortlog", + "blame", + "ls-files", + "ls-tree", + "ls-remote", + "cat-file", + "name-rev", + "config", + "for-each-ref", + "count-objects", + "stash", +} + +SAFE_NPM_SUBCOMMANDS = { + "list", + "ls", + "view", + "info", + "show", + "search", + "outdated", + "explain", + "why", + "fund", + "audit", + "doctor", + "config", +} +SAFE_YARN_SUBCOMMANDS = {"list", "info", "why", "outdated", "config"} +SAFE_PIP_SUBCOMMANDS = {"list", "show", "freeze", "check"} + +DANGEROUS_PATTERNS = ( + re.compile(r"\$\("), + re.compile(r"`[^`]*`"), + re.compile(r"<\("), + re.compile(r">\("), + re.compile(r"(?(?!&\d|/dev/null)"), + re.compile(r">>(?!/dev/null)"), + re.compile(r"\beval\b"), + re.compile(r"\bexec\b"), + re.compile(r"\bsource\b"), +) + + +def parse_command_parts(command: str) -> list[str]: + parts: list[str] = [] + current: list[str] = [] + in_single = False + in_double = False + i = 0 + + while i < len(command): + ch = command[i] + if ch == "\\" and not in_single and i + 1 < len(command): + current.append(ch) + current.append(command[i + 1]) + i += 2 + continue + + if ch == "'" and not in_double: + in_single = not in_single + current.append(ch) + i += 1 + continue + + if ch == '"' and not in_single: + in_double = not in_double + current.append(ch) + i += 1 + continue + + if not in_single and not in_double: + if command.startswith("&&", i) or command.startswith("||", i): + if "".join(current).strip(): + parts.append("".join(current).strip()) + current = [] + i += 2 + continue + if ch in {";", "|"}: + if "".join(current).strip(): + parts.append("".join(current).strip()) + current = [] + i += 1 + continue + + current.append(ch) + i += 1 + + tail = "".join(current).strip() + if tail: + parts.append(tail) + return parts + + +def _extract_command_and_args(part: str) -> tuple[str, list[str]]: + try: + tokens = shlex.split(part, posix=True) + except ValueError: + return "", [] + if not tokens: + return "", [] + + idx = 0 + while idx < len(tokens) and "=" in tokens[idx] and not tokens[idx].startswith(("./", "/")): + left, _right = tokens[idx].split("=", 1) + if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", left): + idx += 1 + continue + break + + if idx >= len(tokens): + return "", [] + + return tokens[idx], tokens[idx + 1 :] + + +def _is_safe_sed(args: list[str]) -> bool: + for arg in args: + if arg in {"-i", "--in-place"}: + return False + if arg.startswith("-i") and len(arg) > 2: + return False + return True + + +def _is_safe_git(args: list[str]) -> bool: + if not args: + return False + sub = args[0] + if sub == "stash": + if len(args) == 1: + return False + return args[1] in {"list", "show"} + return sub in GIT_SAFE_SUBCOMMANDS + + +def _is_safe_pkg(cmd: str, args: list[str]) -> bool: + if not args: + return False + sub = args[0] + if cmd in {"npm", "pnpm"}: + return sub in SAFE_NPM_SUBCOMMANDS + if cmd == "yarn": + return sub in SAFE_YARN_SUBCOMMANDS + if cmd in {"pip", "pip3"}: + return sub in SAFE_PIP_SUBCOMMANDS + return True + + +def _is_readonly_part(part: str) -> bool: + cmd, args = _extract_command_and_args(part) + if not cmd: + return False + + if cmd not in SAFE_COMMANDS: + return False + + if cmd == "git": + return _is_safe_git(args) + if cmd in {"npm", "pnpm", "yarn", "pip", "pip3"}: + return _is_safe_pkg(cmd, args) + if cmd == "sed": + return _is_safe_sed(args) + if cmd == "tee": + return any(a == "/dev/null" for a in args) or not args + return True + + +def is_read_only_command(command: str) -> bool: + stripped = command.strip() + if not stripped: + return False + for pattern in DANGEROUS_PATTERNS: + if pattern.search(stripped): + return False + + parts = parse_command_parts(stripped) + if not parts: + return False + return all(_is_readonly_part(part) for part in parts) diff --git a/eagent/tools/edit.py b/eagent/tools/edit.py new file mode 100644 index 0000000..1896d51 --- /dev/null +++ b/eagent/tools/edit.py @@ -0,0 +1,167 @@ +"""Edit tool for targeted find/replace edits.""" + +from __future__ import annotations + +import difflib +from pathlib import Path +from typing import Any + +from eagent.core.types import FileState, Tool, ToolContext, ToolResult +from eagent.files.atomic_write import atomic_write + + +def _count_occurrences(content: str, needle: str) -> int: + if not needle: + return 0 + return content.count(needle) + + +def _diff_preview(old: str, new: str, context: int = 3) -> str: + diff = difflib.unified_diff( + old.splitlines(), + new.splitlines(), + fromfile="before", + tofile="after", + lineterm="", + n=context, + ) + text = "\n".join(diff) + return text[:4000] + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + raw_path = str( + input_data.get("file_path") or input_data.get("path") or input_data.get("filePath") or "" + ) + if not raw_path: + return ToolResult(result="Error: file_path parameter is required.", is_error=True) + + old_string = str(input_data.get("old_string") or "") + new_string = str(input_data.get("new_string") or "") + replace_all = bool(input_data.get("replace_all") or False) + + file_path = Path(raw_path) + if not file_path.is_absolute(): + file_path = Path(context.cwd) / file_path + file_path = file_path.resolve() + + if old_string == new_string: + return ToolResult(result="Error: old_string and new_string are identical.", is_error=True) + + exists = file_path.exists() + + if old_string == "" and not exists: + atomic_write(str(file_path), new_string) + context.modified_files.add(str(file_path)) + context.file_history.tracked_files.add(str(file_path)) + context.read_file_state.set( + str(file_path), + FileState(content=new_string, timestamp=file_path.stat().st_mtime * 1000), + ) + return ToolResult( + result=f"Created new file: {file_path} ({len(new_string.splitlines())} lines)" + ) + + if not exists: + return ToolResult( + result=( + f"Error: file not found: {file_path}. To create a new file, set old_string to empty string." + ), + is_error=True, + ) + + cached_state = context.read_file_state.get(str(file_path)) + if cached_state is None: + return ToolResult( + result=f"Error: you must Read the file before editing it. Use Read on {file_path} first.", + is_error=True, + ) + + if cached_state.is_partial_view and old_string not in cached_state.content: + return ToolResult( + result=( + "Error: the file was only partially read and the target text was not in the cached segment. " + "Read full file or relevant section first." + ), + is_error=True, + ) + + current_mtime = file_path.stat().st_mtime * 1000 + if current_mtime > cached_state.timestamp + 1000: + return ToolResult( + result="Error: file changed after last read. Read file again before editing.", + is_error=True, + ) + + try: + content = file_path.read_text(encoding="utf-8") + except Exception as exc: + return ToolResult(result=f"Error reading file: {exc}", is_error=True) + + if old_string not in content: + trimmed = old_string.strip() + if trimmed and trimmed in content: + return ToolResult( + result=( + "Error: exact match not found, but a whitespace-trimmed match exists. " + "Ensure old_string matches exact whitespace and line breaks." + ), + is_error=True, + ) + return ToolResult(result=f"Error: old_string not found in {file_path}", is_error=True) + + occurrences = _count_occurrences(content, old_string) + if occurrences > 1 and not replace_all: + return ToolResult( + result=( + f"Error: old_string appears {occurrences} times. Provide more specific context or set replace_all=true." + ), + is_error=True, + ) + + new_content = ( + content.replace(old_string, new_string) + if replace_all + else content.replace(old_string, new_string, 1) + ) + if new_content == content: + return ToolResult(result="No changes applied.") + + try: + atomic_write(str(file_path), new_content) + except Exception as exc: + return ToolResult(result=f"Error writing file: {exc}", is_error=True) + + context.modified_files.add(str(file_path)) + context.file_history.tracked_files.add(str(file_path)) + context.read_file_state.set( + str(file_path), + FileState(content=new_content, timestamp=file_path.stat().st_mtime * 1000), + ) + + preview = _diff_preview(content, new_content) + return ToolResult(result=f"Edited: {file_path}\n\n{preview}") + + +def build_edit_tool() -> Tool: + return Tool( + name="Edit", + description="Apply targeted text replacement in a file.", + input_schema={ + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "old_string": {"type": "string"}, + "new_string": {"type": "string"}, + "replace_all": {"type": "boolean"}, + }, + "required": ["file_path", "old_string", "new_string"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Use Edit for small, precise changes. Read file first.", + is_read_only=lambda _i: False, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=30_000, + user_facing_name=lambda input_data: f"Edit: {input_data.get('file_path') or input_data.get('path')}", + ) diff --git a/eagent/tools/glob_tool.py b/eagent/tools/glob_tool.py new file mode 100644 index 0000000..895cf79 --- /dev/null +++ b/eagent/tools/glob_tool.py @@ -0,0 +1,58 @@ +"""Glob tool for file discovery.""" + +from __future__ import annotations + +import glob +from pathlib import Path +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult + +MAX_RESULTS = 1000 + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + pattern = str(input_data.get("pattern") or "") + if not pattern: + return ToolResult(result="Error: pattern parameter is required.", is_error=True) + + base = str(input_data.get("path") or context.cwd) + base_path = Path(base) + if not base_path.is_absolute(): + base_path = Path(context.cwd) / base_path + base_path = base_path.resolve() + + full_pattern = str(base_path / pattern) + matches = sorted(glob.glob(full_pattern, recursive=True)) + matches = [str(Path(m).resolve()) for m in matches][:MAX_RESULTS] + + if not matches: + return ToolResult(result="No files matched.") + + rel = [ + str(Path(m).relative_to(context.cwd)) if str(Path(m)).startswith(context.cwd) else m + for m in matches + ] + return ToolResult(result="\n".join(rel)) + + +def build_glob_tool() -> Tool: + return Tool( + name="Glob", + description="Find files by glob pattern.", + input_schema={ + "type": "object", + "properties": { + "pattern": {"type": "string"}, + "path": {"type": "string"}, + }, + "required": ["pattern"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Use Glob to discover files before reading/editing.", + is_read_only=lambda _i: True, + is_concurrency_safe=lambda _i: True, + max_result_size_chars=60_000, + user_facing_name=lambda input_data: f"Glob: {input_data.get('pattern')}", + ) diff --git a/eagent/tools/grep_tool.py b/eagent/tools/grep_tool.py new file mode 100644 index 0000000..e7e8d64 --- /dev/null +++ b/eagent/tools/grep_tool.py @@ -0,0 +1,84 @@ +"""Grep-like regex search tool.""" + +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult + +MAX_MATCH_LINES = 2000 + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + pattern = str(input_data.get("pattern") or "") + if not pattern: + return ToolResult(result="Error: pattern parameter is required.", is_error=True) + + base = str(input_data.get("path") or context.cwd) + include = str(input_data.get("include") or "") + case_sensitive = bool(input_data.get("case_sensitive") or False) + + root = Path(base) + if not root.is_absolute(): + root = Path(context.cwd) / root + root = root.resolve() + + flags = 0 if case_sensitive else re.IGNORECASE + try: + regex = re.compile(pattern, flags) + except re.error as exc: + return ToolResult(result=f"Error: invalid regex: {exc}", is_error=True) + + files: list[Path] = [] + if include: + files = [p for p in root.rglob(include) if p.is_file()] + else: + files = [p for p in root.rglob("*") if p.is_file()] + + matches: list[str] = [] + for file_path in files: + try: + text = file_path.read_text(encoding="utf-8", errors="replace") + except Exception: + continue + for idx, line in enumerate(text.splitlines(), start=1): + if regex.search(line): + path_display = ( + str(file_path.relative_to(context.cwd)) + if str(file_path).startswith(context.cwd) + else str(file_path) + ) + matches.append(f"{path_display}:{idx}:{line}") + if len(matches) >= MAX_MATCH_LINES: + matches.append(f"... truncated at {MAX_MATCH_LINES} matches") + return ToolResult(result="\n".join(matches)) + + if not matches: + return ToolResult(result="No matches found.") + return ToolResult(result="\n".join(matches)) + + +def build_grep_tool() -> Tool: + return Tool( + name="Grep", + description="Search text by regex across files.", + input_schema={ + "type": "object", + "properties": { + "pattern": {"type": "string"}, + "path": {"type": "string"}, + "include": {"type": "string"}, + "case_sensitive": {"type": "boolean"}, + }, + "required": ["pattern"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Use Grep for content search with regex patterns.", + is_read_only=lambda _i: True, + is_concurrency_safe=lambda _i: True, + max_result_size_chars=100_000, + user_facing_name=lambda input_data: f"Grep: {input_data.get('pattern')}", + ) diff --git a/eagent/tools/mcp_wrapper.py b/eagent/tools/mcp_wrapper.py new file mode 100644 index 0000000..a1a74bf --- /dev/null +++ b/eagent/tools/mcp_wrapper.py @@ -0,0 +1,55 @@ +"""Helper for wrapping MCP tools as eagent tools.""" + +from __future__ import annotations + +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult + + +class McpToolClientProtocol: + async def call_tool(self, name: str, args: dict[str, Any]) -> dict[str, Any]: ... + + +def wrap_mcp_tool( + server_name: str, tool_name: str, description: str, input_schema: dict[str, Any], client: Any +) -> Tool: + async def _call(input_data: dict[str, Any], _context: ToolContext) -> ToolResult: + try: + result = await client.call_tool(tool_name, input_data) + except Exception as exc: + return ToolResult(result=f"MCP tool error: {exc}", is_error=True) + + content = result.get("content", "") if isinstance(result, dict) else "" + if isinstance(content, list): + text_parts: list[str] = [] + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + text_parts.append(str(block.get("text", ""))) + else: + text_parts.append(str(block)) + else: + text_parts.append(str(block)) + text = "\n".join(text_parts) + else: + text = str(content) + + is_error = ( + bool(result.get("isError") or result.get("is_error")) + if isinstance(result, dict) + else False + ) + return ToolResult(result=text or "(empty result)", is_error=is_error) + + return Tool( + name=f"mcp__{server_name}__{tool_name}", + description=description or f"MCP tool: {tool_name}", + input_schema=input_schema or {"type": "object"}, + call=_call, + prompt=lambda: description or f"MCP tool from {server_name}", + is_concurrency_safe=lambda _i: True, + is_read_only=lambda _i: False, + max_result_size_chars=200_000, + user_facing_name=lambda _i: f"{server_name}:{tool_name}", + ) diff --git a/eagent/tools/notebook_edit.py b/eagent/tools/notebook_edit.py new file mode 100644 index 0000000..15e58fd --- /dev/null +++ b/eagent/tools/notebook_edit.py @@ -0,0 +1,77 @@ +"""NotebookEdit tool for minimal Jupyter cell updates.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult +from eagent.files.atomic_write import atomic_write + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + raw_path = str(input_data.get("file_path") or "") + cell_index = int(input_data.get("cell_index") or -1) + new_source = input_data.get("new_source") + + if not raw_path: + return ToolResult(result="Error: file_path is required.", is_error=True) + if cell_index < 0: + return ToolResult(result="Error: cell_index must be >= 0.", is_error=True) + if new_source is None: + return ToolResult(result="Error: new_source is required.", is_error=True) + + path = Path(raw_path) + if not path.is_absolute(): + path = Path(context.cwd) / path + path = path.resolve() + + try: + notebook = json.loads(path.read_text(encoding="utf-8")) + except Exception as exc: + return ToolResult(result=f"Error reading notebook: {exc}", is_error=True) + + cells = notebook.get("cells") + if not isinstance(cells, list): + return ToolResult(result="Error: invalid notebook format (missing cells).", is_error=True) + if cell_index >= len(cells): + return ToolResult( + result=f"Error: cell_index out of range (0..{len(cells)-1}).", is_error=True + ) + + source_text = str(new_source) + cells[cell_index]["source"] = source_text.splitlines(keepends=True) + + try: + atomic_write(str(path), json.dumps(notebook, ensure_ascii=False, indent=2) + "\n") + except Exception as exc: + return ToolResult(result=f"Error writing notebook: {exc}", is_error=True) + + context.modified_files.add(str(path)) + context.file_history.tracked_files.add(str(path)) + + return ToolResult(result=f"Updated notebook cell {cell_index} in {path}.") + + +def build_notebook_edit_tool() -> Tool: + return Tool( + name="NotebookEdit", + description="Edit a Jupyter notebook cell by index.", + input_schema={ + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "cell_index": {"type": "integer", "minimum": 0}, + "new_source": {"type": "string"}, + }, + "required": ["file_path", "cell_index", "new_source"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Use NotebookEdit for focused notebook cell changes.", + is_read_only=lambda _i: False, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=10_000, + user_facing_name=lambda input_data: f"NotebookEdit: {input_data.get('file_path')}", + ) diff --git a/eagent/tools/plan_mode.py b/eagent/tools/plan_mode.py new file mode 100644 index 0000000..d9b2927 --- /dev/null +++ b/eagent/tools/plan_mode.py @@ -0,0 +1,48 @@ +"""Plan mode helper tools.""" + +from __future__ import annotations + +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult + + +async def _enter_call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + del input_data + context.permission_mode = "plan" + return ToolResult(result="Plan mode enabled (read-only).") + + +async def _exit_call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + del input_data + if context.permission_mode == "plan": + context.permission_mode = "default" + return ToolResult(result=f"Permission mode: {context.permission_mode}") + + +def build_enter_plan_mode_tool() -> Tool: + return Tool( + name="EnterPlanMode", + description="Switch to read-only plan mode.", + input_schema={"type": "object", "properties": {}, "additionalProperties": False}, + call=_enter_call, + prompt=lambda: "Switch to plan mode when only planning is required.", + is_read_only=lambda _i: False, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=2000, + user_facing_name=lambda _i: "EnterPlanMode", + ) + + +def build_exit_plan_mode_tool() -> Tool: + return Tool( + name="ExitPlanMode", + description="Return from plan mode to default mode.", + input_schema={"type": "object", "properties": {}, "additionalProperties": False}, + call=_exit_call, + prompt=lambda: "Exit plan mode once implementation can start.", + is_read_only=lambda _i: False, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=2000, + user_facing_name=lambda _i: "ExitPlanMode", + ) diff --git a/eagent/tools/read.py b/eagent/tools/read.py new file mode 100644 index 0000000..4b83a3a --- /dev/null +++ b/eagent/tools/read.py @@ -0,0 +1,130 @@ +"""Read tool.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from eagent.core.types import FileState, Tool, ToolContext, ToolResult +from eagent.files.utils import detect_encoding, format_with_line_numbers, is_binary_data + +DEFAULT_LIMIT = 2000 +MAX_RESULT_SIZE_CHARS = 60_000 +IMAGE_EXTENSIONS = { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".ico", + ".svg", + ".webp", + ".tiff", + ".tif", + ".psd", + ".raw", + ".heif", + ".heic", +} + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + raw_path = str(input_data.get("file_path") or input_data.get("path") or "") + if not raw_path: + return ToolResult(result="Error: file_path parameter is required.", is_error=True) + + file_path = Path(raw_path) + if not file_path.is_absolute(): + file_path = Path(context.cwd) / file_path + file_path = file_path.resolve() + + offset = int(input_data.get("offset") or 1) + limit = int(input_data.get("limit") or DEFAULT_LIMIT) + offset = max(1, offset) + limit = max(1, limit) + + if file_path.suffix.lower() in IMAGE_EXTENSIONS: + return ToolResult( + result=f"This is an image file ({file_path.suffix}). Use an image-capable viewer.", + is_error=False, + ) + + try: + data = file_path.read_bytes() + except FileNotFoundError: + return ToolResult(result=f"Error: file not found: {file_path}", is_error=True) + except IsADirectoryError: + return ToolResult( + result=f"Error: {file_path} is a directory. Use Glob or Bash ls/find to inspect it.", + is_error=True, + ) + except PermissionError: + return ToolResult(result=f"Error: permission denied: {file_path}", is_error=True) + except Exception as exc: + return ToolResult(result=f"Error reading file: {exc}", is_error=True) + + if is_binary_data(data): + return ToolResult(result=f"This is a binary file ({len(data)} bytes).", is_error=False) + + encoding = detect_encoding(data) + if encoding == "utf-16le": + content = data[2:].decode("utf-16le", errors="replace") + else: + content = data.decode("utf-8-sig", errors="replace") + + all_lines = content.split("\n") + total = len(all_lines) + start = offset - 1 + end = min(start + limit, total) + selected = all_lines[start:end] + + rendered = format_with_line_numbers("\n".join(selected), start_line=offset) + if start > 0 or end < total: + meta: list[str] = [] + if start > 0: + meta.append(f"(showing from line {offset})") + if end < total: + meta.append(f"({total - end} more lines below, {total} total)") + rendered += "\n" + " ".join(meta) + + context.read_file_state.set( + str(file_path), + FileState( + content=content, + timestamp=file_path.stat().st_mtime * 1000, + offset=offset, + limit=limit, + is_partial_view=start > 0 or end < total, + ), + ) + + if len(rendered) > MAX_RESULT_SIZE_CHARS: + rendered = ( + rendered[:MAX_RESULT_SIZE_CHARS] + + f"\n\n[Output truncated at {MAX_RESULT_SIZE_CHARS} chars]" + ) + + return ToolResult(result=rendered) + + +def build_read_tool() -> Tool: + return Tool( + name="Read", + description="Read a file with line numbers.", + input_schema={ + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "offset": {"type": "integer", "minimum": 1}, + "limit": {"type": "integer", "minimum": 1}, + }, + "required": ["file_path"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Read files before editing; use offset/limit for large files.", + is_read_only=lambda _i: True, + is_concurrency_safe=lambda _i: True, + max_result_size_chars=MAX_RESULT_SIZE_CHARS, + user_facing_name=lambda input_data: f"Read: {input_data.get('file_path') or input_data.get('path')}", + ) diff --git a/eagent/tools/registry.py b/eagent/tools/registry.py new file mode 100644 index 0000000..fdd18ad --- /dev/null +++ b/eagent/tools/registry.py @@ -0,0 +1,139 @@ +"""Tool registry and initialization.""" + +from __future__ import annotations + +from collections.abc import Callable + +from eagent.core.types import Tool +from eagent.tools.agent_tool import build_agent_tool +from eagent.tools.ask import build_ask_tool +from eagent.tools.bash import build_bash_tool +from eagent.tools.edit import build_edit_tool +from eagent.tools.glob_tool import build_glob_tool +from eagent.tools.grep_tool import build_grep_tool +from eagent.tools.notebook_edit import build_notebook_edit_tool +from eagent.tools.plan_mode import build_enter_plan_mode_tool, build_exit_plan_mode_tool +from eagent.tools.read import build_read_tool +from eagent.tools.todo import build_todo_tool +from eagent.tools.web_fetch import build_web_fetch_tool +from eagent.tools.web_search import build_web_search_tool +from eagent.tools.write import build_write_tool + +_registry: dict[str, Tool] = {} +_initialized = False + + +def build_tool(definition: Tool) -> Tool: + return Tool( + name=definition.name, + description=definition.description, + input_schema=definition.input_schema, + call=definition.call, + prompt=definition.prompt or (lambda: ""), + is_concurrency_safe=definition.is_concurrency_safe or (lambda _i: False), + is_read_only=definition.is_read_only or (lambda _i: False), + max_result_size_chars=definition.max_result_size_chars or 30_000, + user_facing_name=definition.user_facing_name or (lambda _i: definition.name), + ) + + +def register_tool(tool: Tool) -> None: + _registry[tool.name] = build_tool(tool) + + +def get_all_tools() -> list[Tool]: + return list(_registry.values()) + + +def get_tool_by_name(name: str) -> Tool | None: + return _registry.get(name) + + +def has_tool_by_name(name: str) -> bool: + return name in _registry + + +def get_tool_count() -> int: + return len(_registry) + + +async def initialize_tools(cwd: str | None = None) -> list[Tool]: + global _initialized + if _initialized and _registry: + return get_all_tools() + + builders: list[Callable[[], Tool]] = [ + build_bash_tool, + build_read_tool, + build_edit_tool, + build_write_tool, + build_glob_tool, + build_grep_tool, + build_agent_tool, + build_ask_tool, + build_todo_tool, + build_web_fetch_tool, + build_web_search_tool, + build_enter_plan_mode_tool, + build_exit_plan_mode_tool, + build_notebook_edit_tool, + ] + + for factory in builders: + register_tool(factory()) + + # Optional skill tool. + try: + from eagent.skills.skill_tool import build_skill_tool, initialize_skills + + await initialize_skills(cwd or ".") + register_tool(build_skill_tool()) + except Exception: + pass + + _initialized = True + return get_all_tools() + + +def register_dynamic_tools(tools: list[Tool]) -> None: + for tool in tools: + register_tool(tool) + + +def reset_registry() -> None: + global _initialized + _registry.clear() + _initialized = False + + +def get_read_only_tool_names() -> list[str]: + names: list[str] = [] + for tool in get_all_tools(): + try: + if tool.is_read_only({}): + names.append(tool.name) + except Exception: + continue + return names + + +def filter_tools_by_name(names: list[str]) -> list[Tool]: + allow = set(names) + return [tool for tool in get_all_tools() if tool.name in allow] + + +def exclude_tools_by_name(names: list[str]) -> list[Tool]: + deny = set(names) + return [tool for tool in get_all_tools() if tool.name not in deny] + + +def generate_tool_summary() -> str: + tools = get_all_tools() + if not tools: + return "(No tools registered)" + + lines = ["Available tools:"] + for tool in tools: + desc = tool.description({}) if callable(tool.description) else tool.description + lines.append(f" - {tool.name}: {desc}") + return "\n".join(lines) diff --git a/eagent/tools/todo.py b/eagent/tools/todo.py new file mode 100644 index 0000000..aae28af --- /dev/null +++ b/eagent/tools/todo.py @@ -0,0 +1,92 @@ +"""Todo tool for lightweight task tracking.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from eagent.core.types import Tool, ToolContext, ToolResult +from eagent.paths import env_root + + +def _todo_file(session_id: str) -> Path: + path = env_root() / "todo" + path.mkdir(parents=True, exist_ok=True) + return path / f"{session_id}.json" + + +def _load(path: Path) -> list[dict[str, Any]]: + if not path.exists(): + return [] + try: + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return [] + return data if isinstance(data, list) else [] + + +def _save(path: Path, items: list[dict[str, Any]]) -> None: + path.write_text(json.dumps(items, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + action = str(input_data.get("action") or "list") + text = str(input_data.get("text") or "").strip() + idx = input_data.get("index") + + path = _todo_file(context.session_id) + items = _load(path) + + if action == "add": + if not text: + return ToolResult(result="Error: text is required for add action.", is_error=True) + items.append({"text": text, "done": False}) + _save(path, items) + return ToolResult(result=f"Added todo #{len(items)}: {text}") + + if action == "done": + if not isinstance(idx, int) or idx < 1 or idx > len(items): + return ToolResult( + result="Error: valid index is required for done action.", is_error=True + ) + items[idx - 1]["done"] = True + _save(path, items) + return ToolResult(result=f"Marked todo #{idx} done.") + + if action == "remove": + if not isinstance(idx, int) or idx < 1 or idx > len(items): + return ToolResult( + result="Error: valid index is required for remove action.", is_error=True + ) + removed = items.pop(idx - 1) + _save(path, items) + return ToolResult(result=f"Removed todo #{idx}: {removed.get('text', '')}") + + lines = [] + for i, item in enumerate(items, start=1): + mark = "x" if item.get("done") else " " + lines.append(f"{i}. [{mark}] {item.get('text', '')}") + return ToolResult(result="\n".join(lines) if lines else "No todos.") + + +def build_todo_tool() -> Tool: + return Tool( + name="Todo", + description="Manage a simple per-session todo list.", + input_schema={ + "type": "object", + "properties": { + "action": {"type": "string", "enum": ["list", "add", "done", "remove"]}, + "text": {"type": "string"}, + "index": {"type": "integer", "minimum": 1}, + }, + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Track progress with add/done/remove/list todo actions.", + is_read_only=lambda _i: False, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=20_000, + user_facing_name=lambda _i: "Todo", + ) diff --git a/eagent/tools/web_fetch.py b/eagent/tools/web_fetch.py new file mode 100644 index 0000000..1c2c298 --- /dev/null +++ b/eagent/tools/web_fetch.py @@ -0,0 +1,64 @@ +"""Web fetch tool.""" + +from __future__ import annotations + +from typing import Any + +import httpx + +from eagent.core.types import Tool, ToolContext, ToolResult + +MAX_CHARS = 40_000 + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + _ = context + url = str(input_data.get("url") or "").strip() + if not url: + return ToolResult(result="Error: url is required.", is_error=True) + timeout = float(input_data.get("timeout") or 20) + max_chars = int(input_data.get("max_chars") or MAX_CHARS) + + try: + async with httpx.AsyncClient(follow_redirects=True, timeout=timeout) as client: + response = await client.get(url) + except Exception as exc: + return ToolResult(result=f"Error fetching URL: {exc}", is_error=True) + + content_type = response.headers.get("content-type", "") + text = response.text + if len(text) > max_chars: + text = text[:max_chars] + f"\n\n[Truncated to {max_chars} chars]" + + return ToolResult( + result=( + f"URL: {response.url}\n" + f"Status: {response.status_code}\n" + f"Content-Type: {content_type}\n\n" + f"{text}" + ), + is_error=response.status_code >= 400, + ) + + +def build_web_fetch_tool() -> Tool: + return Tool( + name="WebFetch", + description="Fetch and return text from a URL.", + input_schema={ + "type": "object", + "properties": { + "url": {"type": "string"}, + "timeout": {"type": "number"}, + "max_chars": {"type": "integer", "minimum": 1000}, + }, + "required": ["url"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Fetch a specific URL when precise page content is needed.", + is_read_only=lambda _i: True, + is_concurrency_safe=lambda _i: True, + max_result_size_chars=80_000, + user_facing_name=lambda input_data: f"WebFetch: {input_data.get('url')}", + ) diff --git a/eagent/tools/web_search.py b/eagent/tools/web_search.py new file mode 100644 index 0000000..ebab18e --- /dev/null +++ b/eagent/tools/web_search.py @@ -0,0 +1,238 @@ +"""Web search tool using lightweight DuckDuckGo HTML endpoint.""" + +from __future__ import annotations + +import base64 +import html +import os +import re +from typing import Any +from urllib.parse import parse_qs, unquote, urlencode, urlparse + +import httpx + +from eagent.core.types import Tool, ToolContext, ToolResult + +RESULT_LIMIT = 10 +DEFAULT_SEARCH_PROVIDER = "auto" +DEFAULT_SEARCH_USER_AGENT = "Mozilla/5.0" +BING_SEARCH_URL = "https://www.bing.com/search" +DDG_SEARCH_URL = "https://duckduckgo.com/html/" +VALID_PROVIDERS = {"auto", "ddg", "bing"} +SearchRows = list[tuple[str, str, str]] + + +def _strip_tags(text: str) -> str: + return re.sub(r"<[^>]+>", "", text) + + +def _normalize_ddg_href(href: str) -> str: + normalized = href.strip() + if normalized.startswith("//"): + normalized = "https:" + normalized + + parsed = urlparse(normalized) + if "duckduckgo.com" in parsed.netloc and parsed.path.startswith("/l/"): + target = parse_qs(parsed.query).get("uddg") + if target: + return unquote(target[0]) + return normalized + + +def _normalize_bing_href(href: str) -> str: + normalized = href.strip() + parsed = urlparse(normalized) + if "bing.com" not in parsed.netloc or parsed.path != "/ck/a": + return normalized + + encoded = parse_qs(parsed.query).get("u") + if not encoded: + return normalized + token = encoded[0] + if token.startswith("a1"): + token = token[2:] + + token = token.replace("-", "+").replace("_", "/") + token += "=" * (-len(token) % 4) + try: + decoded = base64.b64decode(token).decode("utf-8") + except Exception: + return normalized + if decoded.startswith("http://") or decoded.startswith("https://"): + return decoded + return normalized + + +def _extract_ddg(content: str, limit: int) -> SearchRows: + rows = re.findall( + r']*class="[^"]*result__a[^"]*"[^>]*href="(?P[^"]+)"[^>]*>(?P.*?)</a>', + content, + flags=re.IGNORECASE | re.DOTALL, + ) + snippets = re.findall( + r'<(?:a|div)[^>]*class="[^"]*result__snippet[^"]*"[^>]*>(.*?)</(?:a|div)>', + content, + flags=re.IGNORECASE | re.DOTALL, + ) + + out: SearchRows = [] + for idx, (href, title_html) in enumerate(rows[:limit]): + title = html.unescape(_strip_tags(title_html)).strip() + snippet = html.unescape(_strip_tags(snippets[idx])).strip() if idx < len(snippets) else "" + out.append((title, _normalize_ddg_href(href), snippet)) + return out + + +def _extract_bing(content: str, limit: int) -> SearchRows: + pattern = r'<li[^>]*class="[^"]*\bb_algo\b[^"]*"[^>]*>' + item_starts = [m.start() for m in re.finditer(pattern, content, re.IGNORECASE)] + if not item_starts: + return [] + + out: SearchRows = [] + for idx, start in enumerate(item_starts): + end = item_starts[idx + 1] if idx + 1 < len(item_starts) else len(content) + item = content[start:end] + header = re.search( + r'<h2[^>]*>\s*<a[^>]*href="(?P<href>[^"]+)"[^>]*>(?P<title>.*?)</a>', + item, + flags=re.IGNORECASE | re.DOTALL, + ) + if not header: + continue + href = _normalize_bing_href(html.unescape(header.group("href")).strip()) + title = html.unescape(_strip_tags(header.group("title"))).strip() + snippet_match = re.search(r"<p[^>]*>(.*?)</p>", item, flags=re.IGNORECASE | re.DOTALL) + snippet = ( + html.unescape(_strip_tags(snippet_match.group(1))).strip() if snippet_match else "" + ) + out.append((title, href, snippet)) + if len(out) >= limit: + break + return out + + +def _format_results(rows: SearchRows) -> str: + lines: list[str] = [] + for idx, (title, href, snippet) in enumerate(rows, start=1): + lines.append(f"{idx}. {title}\n URL: {href}") + if snippet: + lines.append(f" Snippet: {snippet}") + return "\n".join(lines) + + +def _resolve_provider(input_data: dict[str, Any]) -> tuple[str | None, str | None]: + raw = str( + input_data.get("provider") + or os.getenv("ENV_AGENT_WEB_SEARCH_PROVIDER", DEFAULT_SEARCH_PROVIDER) + ).strip() + provider = raw.lower() + if provider in VALID_PROVIDERS: + return provider, None + valid = ", ".join(sorted(VALID_PROVIDERS)) + return None, f"Error: invalid provider '{raw}'. Valid values: {valid}." + + +def _build_client_headers() -> dict[str, str]: + user_agent = os.getenv("ENV_AGENT_WEB_SEARCH_USER_AGENT", DEFAULT_SEARCH_USER_AGENT).strip() + if not user_agent: + user_agent = DEFAULT_SEARCH_USER_AGENT + return {"User-Agent": user_agent} + + +async def _search_ddg( + client: httpx.AsyncClient, ddg_url: str, limit: int +) -> tuple[SearchRows | None, str | None]: + response = await client.get(ddg_url) + if response.status_code >= 400: + return None, f"DuckDuckGo search failed with status {response.status_code}." + return _extract_ddg(response.text, limit), None + + +async def _search_bing( + client: httpx.AsyncClient, bing_url: str, limit: int +) -> tuple[SearchRows | None, str | None]: + response = await client.get(bing_url) + if response.status_code >= 400: + return None, f"Bing search failed with status {response.status_code}." + return _extract_bing(response.text, limit), None + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + _ = context + query = str(input_data.get("query") or "").strip() + if not query: + return ToolResult(result="Error: query is required.", is_error=True) + + limit = int(input_data.get("limit") or RESULT_LIMIT) + limit = max(1, min(limit, 20)) + provider, provider_error = _resolve_provider(input_data) + if provider_error: + return ToolResult(result=provider_error, is_error=True) + assert provider is not None + + params = urlencode({"q": query}) + ddg_url = f"{DDG_SEARCH_URL}?{params}" + bing_url = f"{BING_SEARCH_URL}?{params}" + errors: list[str] = [] + + try: + async with httpx.AsyncClient( + timeout=20, + follow_redirects=True, + headers=_build_client_headers(), + ) as client: + if provider in {"auto", "ddg"}: + try: + ddg_rows, ddg_error = await _search_ddg(client, ddg_url, limit) + except Exception as exc: + ddg_rows, ddg_error = None, f"DuckDuckGo search error: {exc}" + if ddg_rows: + return ToolResult(result=_format_results(ddg_rows)) + if ddg_error: + errors.append(ddg_error) + if provider == "ddg": + if errors: + return ToolResult(result=" ".join(errors), is_error=True) + return ToolResult(result="No search results parsed.") + + if provider in {"auto", "bing"}: + try: + bing_rows, bing_error = await _search_bing(client, bing_url, limit) + except Exception as exc: + bing_rows, bing_error = None, f"Bing search error: {exc}" + if bing_rows: + return ToolResult(result=_format_results(bing_rows)) + if bing_error: + errors.append(bing_error) + if provider == "bing" and errors: + return ToolResult(result=" ".join(errors), is_error=True) + except Exception as exc: + return ToolResult(result=f"Error searching web: {exc}", is_error=True) + + if errors and provider == "auto" and len(errors) >= 2: + return ToolResult(result=" ".join(errors), is_error=True) + return ToolResult(result="No search results parsed.") + + +def build_web_search_tool() -> Tool: + return Tool( + name="WebSearch", + description="Search the web for recent/public information.", + input_schema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer", "minimum": 1, "maximum": 20}, + "provider": {"type": "string", "enum": ["auto", "ddg", "bing"]}, + }, + "required": ["query"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Use WebSearch when external current information is required.", + is_read_only=lambda _i: True, + is_concurrency_safe=lambda _i: True, + max_result_size_chars=60_000, + user_facing_name=lambda input_data: f"WebSearch: {input_data.get('query')}", + ) diff --git a/eagent/tools/write.py b/eagent/tools/write.py new file mode 100644 index 0000000..31de3c2 --- /dev/null +++ b/eagent/tools/write.py @@ -0,0 +1,79 @@ +"""Write tool for complete file overwrite/create.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from eagent.core.types import FileState, Tool, ToolContext, ToolResult +from eagent.files.atomic_write import atomic_write + + +async def _call(input_data: dict[str, Any], context: ToolContext) -> ToolResult: + raw_path = str( + input_data.get("file_path") or input_data.get("path") or input_data.get("filePath") or "" + ) + if not raw_path: + return ToolResult(result="Error: file_path parameter is required.", is_error=True) + + content = str(input_data.get("content") or "") + + file_path = Path(raw_path) + if not file_path.is_absolute(): + file_path = Path(context.cwd) / file_path + file_path = file_path.resolve() + + if file_path.exists(): + cached = context.read_file_state.get(str(file_path)) + if cached is None: + return ToolResult( + result=( + f"Error: file already exists at {file_path}. Read it first before overwrite, " + "or use Edit for targeted updates." + ), + is_error=True, + ) + current_mtime = file_path.stat().st_mtime * 1000 + if current_mtime > cached.timestamp + 1000: + return ToolResult( + result="Error: file changed after last read. Read again before write.", + is_error=True, + ) + + try: + atomic_write(str(file_path), content) + except PermissionError: + return ToolResult(result=f"Error: permission denied writing to {file_path}", is_error=True) + except Exception as exc: + return ToolResult(result=f"Error writing file: {exc}", is_error=True) + + context.modified_files.add(str(file_path)) + context.file_history.tracked_files.add(str(file_path)) + context.read_file_state.set( + str(file_path), + FileState(content=content, timestamp=file_path.stat().st_mtime * 1000), + ) + + return ToolResult(result=f"Written: {file_path} ({len(content.splitlines())} lines)") + + +def build_write_tool() -> Tool: + return Tool( + name="Write", + description="Write complete file content (create or overwrite).", + input_schema={ + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "content": {"type": "string"}, + }, + "required": ["file_path", "content"], + "additionalProperties": False, + }, + call=_call, + prompt=lambda: "Use Write for full-file creation/overwrite. Prefer Edit for targeted changes.", + is_read_only=lambda _i: False, + is_concurrency_safe=lambda _i: False, + max_result_size_chars=30_000, + user_facing_name=lambda input_data: f"Write: {input_data.get('file_path') or input_data.get('path')}", + ) diff --git a/eagent/tui/__init__.py b/eagent/tui/__init__.py new file mode 100644 index 0000000..acfcb8f --- /dev/null +++ b/eagent/tui/__init__.py @@ -0,0 +1,6 @@ +"""TUI components for eagent CLI.""" + +from eagent.tui.app import EnvAgentTui, TuiState + +__all__ = ["EnvAgentTui", "TuiState"] + diff --git a/eagent/tui/agent_picker.py b/eagent/tui/agent_picker.py new file mode 100644 index 0000000..c3ebac9 --- /dev/null +++ b/eagent/tui/agent_picker.py @@ -0,0 +1,80 @@ +"""Agent profile picker state and rendering.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field + +from prompt_toolkit.formatted_text import FormattedText + + +@dataclass +class AgentPicker: + items: list[tuple[str, str]] = field(default_factory=list) + index: int = 0 + future: asyncio.Future[str | None] | None = None + + @property + def active(self) -> bool: + return self.future is not None and not self.future.done() + + def open(self, items: list[tuple[str, str]], loop: asyncio.AbstractEventLoop) -> None: + self.items = list(items) + self.index = 0 + self.future = loop.create_future() + + async def wait(self) -> str | None: + if self.future is None: + return None + return await self.future + + def close(self) -> None: + self.future = None + self.items = [] + self.index = 0 + + def move(self, step: int) -> None: + if not self.active or not self.items: + return + self.index = (self.index + step) % len(self.items) + + def confirm(self) -> None: + if not self.active or not self.items or self.future is None: + return + self.future.set_result(self.items[self.index][0]) + + def cancel(self) -> None: + if self.future is None or self.future.done(): + return + self.future.set_result(None) + + def render(self) -> FormattedText: + if not self.active or not self.items: + return FormattedText([]) + + title = "Agent 配置列表(上下键移动,Enter 选择,Esc 取消)" + rows = [title] + rows.extend( + f"{'>' if idx == self.index else ' '} {label}" + for idx, (_name, label) in enumerate(self.items) + ) + inner_width = min(76, max(len(row) for row in rows)) + + def clip(text: str) -> str: + if len(text) <= inner_width: + return text + if inner_width <= 3: + return text[:inner_width] + return text[: inner_width - 3] + "..." + + border = "+-" + "-" * inner_width + "-+\n" + fragments: list[tuple[str, str]] = [("class:toolpanel.summary", border)] + fragments.append(("class:toolpanel.summary", f"| {clip(title).ljust(inner_width)} |\n")) + fragments.append(("class:toolpanel.summary", border)) + for idx, (_name, label) in enumerate(self.items): + marker = ">" if idx == self.index else " " + row_text = clip(f"{marker} {label}").ljust(inner_width) + style = "class:toolpanel.summary" if idx == self.index else "class:toolpanel.detail" + fragments.append((style, f"| {row_text} |\n")) + fragments.append(("class:toolpanel.summary", border)) + return FormattedText(fragments) diff --git a/eagent/tui/app.py b/eagent/tui/app.py new file mode 100644 index 0000000..330b6ba --- /dev/null +++ b/eagent/tui/app.py @@ -0,0 +1,1295 @@ +"""Prompt-toolkit full-screen TUI for eagent.""" + +from __future__ import annotations + +import asyncio +import contextlib +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any + +from prompt_toolkit.application import Application +from prompt_toolkit.completion import Completer +from prompt_toolkit.data_structures import Point +from prompt_toolkit.document import Document +from prompt_toolkit.filters import Condition, has_focus +from prompt_toolkit.formatted_text import FormattedText +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.layout import HSplit, Layout, Window +from prompt_toolkit.layout.containers import ConditionalContainer, Float, FloatContainer +from prompt_toolkit.layout.controls import FormattedTextControl +from prompt_toolkit.layout.dimension import D +from prompt_toolkit.layout.margins import ScrollbarMargin +from prompt_toolkit.layout.menus import CompletionsMenu +from prompt_toolkit.shortcuts import radiolist_dialog +from prompt_toolkit.widgets import TextArea + +from eagent.core.types import PermissionDecision, PermissionRule +from eagent.permissions.engine import add_session_rule +from eagent.tui.agent_picker import AgentPicker +from eagent.tui.status_bar import SPINNER_FRAMES, StatusBarRenderer, StatusBarState, StatusMeta +from eagent.tui.styles import TUI_STYLE +from eagent.utils.completer import extract_mention_query + + +@dataclass +class TuiState: + status: str = "idle" + status_text: str = "Ready" + spinner_index: int = 0 + busy: bool = False + input_mode: str = "compose" + + def on_submit(self) -> None: + self.status = "running" + self.status_text = "思考中..." + self.busy = True + + def on_tool_start(self) -> None: + self.status_text = "调用工具..." + + def on_compact(self) -> None: + self.status_text = "整理上下文..." + + def on_assistant_text(self) -> None: + if self.status == "running": + self.status_text = "输出中..." + + def on_error(self) -> None: + self.status = "error" + self.status_text = "出现错误" + self.busy = False + + def on_turn_complete(self) -> None: + self.status = "idle" + self.status_text = "Ready" + self.busy = False + + def on_history_view(self) -> None: + self.input_mode = "history_view" + + def on_edit_mode(self) -> None: + self.input_mode = "compose" + + def on_agent_picker_mode(self) -> None: + self.input_mode = "agent_picker" + + +@dataclass +class ToolPanel: + seq: int + tool_use_id: str + tool_name: str + input_preview: str + started_at: float + status: str = "running" + is_error: bool = False + duration_ms: int | None = None + result_preview: str = "" + result_detail: str = "" + expanded: bool = False + + +@dataclass +class ActivityItem: + seq: int + text: str + status: str = "running" + started_at: float = field(default_factory=time.monotonic) + ended_at: float | None = None + + +@dataclass +class TranscriptChunk: + style: str + text: str + + +class EnvAgentTui: + def __init__( + self, + session_id: str, + get_status_meta: Callable[[], StatusMeta], + on_prompt: Callable[ + [str, Callable[[str], None], Callable[[dict[str, Any]], None]], Awaitable[None] + ], + on_command: Callable[ + [str, Callable[[str], None], Callable[[dict[str, Any]], None]], Awaitable[bool] + ], + completer: Completer | None = None, + startup_messages: list[str] | None = None, + list_agents: Callable[[], list[tuple[str, str]]] | None = None, + on_agent_select: Callable[[str], Awaitable[str]] | None = None, + command_specs: list[dict[str, Any]] | None = None, + dev_mode: bool = False, + ) -> None: + self._session_id = session_id + self._get_status_meta = get_status_meta + self._on_prompt = on_prompt + self._on_command = on_command + self._startup_messages = startup_messages or [] + self._list_agents = list_agents + self._on_agent_select = on_agent_select + self._dev_mode = dev_mode + self._command_specs = { + str(spec.get("name", "")): spec for spec in (command_specs or []) if spec.get("name") + } + self._command_alias_to_name = self._build_command_alias_map(command_specs or []) + self.state = TuiState() + self._running = True + self._assistant_block_open = False + self._input_history: list[str] = [] + self._history_index: int | None = None + self._default_hint = ( + "Enter 发送 | Alt+Enter 换行 | Esc/Ctrl+C 中断 | " + "Tab 补全 | F2/F3/F4 工具面板 | F6 活动轨模式" + ) + if self._dev_mode: + self._default_hint += " | Ctrl-R 重载" + self._idle_short_hint = "Enter 发送 | /help" + if self._dev_mode: + self._idle_short_hint += " | Ctrl-R 重载" + self._hint_text = self._default_hint + self._tool_panels: list[ToolPanel] = [] + self._tool_panel_index_by_id: dict[str, int] = {} + self._selected_tool_panel: int = 0 + self._activities: list[ActivityItem] = [] + self._tool_activity_seq: dict[str, int] = {} + self._activity_seq: int = 0 + self._animation_tick: int = 0 + self._turn_started_at: float | None = None + self._model_wait_seq: int | None = None + self._stream_seq: int | None = None + self._assistant_started_in_turn: bool = False + self._assistant_streaming_cursor: bool = False + self._assistant_chunk_index: int | None = None + self._activity_compact_mode: bool = True + self._active_turn_task: asyncio.Task[None] | None = None + self._queued_inputs: list[str] = [] + self._status_tool_text: str | None = None + self._status_tool_running: bool = False + self._status_tool_error: bool = False + self._agent_picker = AgentPicker() + self._transcript_chunks: list[TranscriptChunk] = [] + self._transcript_cursor_line: int = 0 + self._transcript_cursor_col: int = 0 + self._suspend_mention_autocomplete_once: bool = False + self._status_bar_renderer = StatusBarRenderer() + + self.transcript = Window( + content=FormattedTextControl( + self._transcript_fragments, + get_cursor_position=self._transcript_cursor_position, + show_cursor=False, + ), + style="class:transcript", + wrap_lines=True, + right_margins=[ScrollbarMargin(display_arrows=False)], + ) + history_read_only = Condition( + lambda: self.state.input_mode in {"history_view", "agent_picker"} + ) + self.input = TextArea( + height=self._input_height_dimension, + prompt="> ", + style="class:input", + multiline=True, + wrap_lines=True, + completer=completer, + complete_while_typing=True, + read_only=history_read_only, + ) + self.input.buffer.on_text_changed += self._on_input_text_changed + + self.status_bar = Window( + content=FormattedTextControl(self._status_fragments), + height=D.exact(1), + style="class:status", + ) + self.activity_rail = Window( + content=FormattedTextControl(self._activity_fragments), + height=D(min=1, max=4), + style="class:activity", + ) + self._activity_rail_visible = Condition( + lambda: self.state.busy or self._running_activity_count() > 0 + ) + self.activity_rail_container = ConditionalContainer( + content=self.activity_rail, + filter=self._activity_rail_visible, + ) + self.tool_panel = Window( + content=FormattedTextControl(self._tool_panel_fragments), + height=D(min=1, max=8), + style="class:toolpanel", + ) + self._tool_panel_visible = Condition(lambda: bool(self._tool_panels)) + self.tool_panel_container = ConditionalContainer( + content=self.tool_panel, + filter=self._tool_panel_visible, + ) + self._hint_visible = Condition(lambda: self.app.output.get_size().columns >= 72) + self.hint_bar = Window( + content=FormattedTextControl(self._hint_fragments), + height=D.exact(1), + style="class:hint", + ) + self.hint_bar_container = ConditionalContainer( + content=self.hint_bar, + filter=self._hint_visible, + ) + self.agent_picker_panel = Window( + content=FormattedTextControl(self._agent_picker_fragments), + height=D(min=3, max=8), + style="class:toolpanel", + ) + self._agent_picker_visible = Condition(lambda: self._agent_picker.active) + self.agent_picker_container = ConditionalContainer( + content=self.agent_picker_panel, + filter=self._agent_picker_visible, + ) + + body = HSplit( + [ + self.transcript, + self.tool_panel_container, + self.activity_rail_container, + self.agent_picker_container, + self.input, + self.status_bar, + self.hint_bar_container, + ] + ) + root_container = FloatContainer( + content=body, + floats=[ + Float( + xcursor=True, + ycursor=True, + width=lambda: max(40, self.app.output.get_size().columns - 4), + content=CompletionsMenu(max_height=8, scroll_offset=1), + ) + ], + ) + kb = KeyBindings() + focus_input = has_focus(self.input) + + @kb.add("c-d") + def _exit(_event) -> None: + self._running = False + self.app.exit() + + @kb.add("c-c") + def _ctrl_c(event) -> None: + self._handle_ctrl_c() + event.app.invalidate() + + @kb.add("c-r", filter=focus_input & Condition(lambda: self._dev_mode)) + def _reload(event) -> None: + if self._trigger_dev_reload(): + event.app.invalidate() + + @kb.add("enter", filter=focus_input) + def _enter(event) -> None: + if self._agent_picker.active: + self._confirm_agent_picker() + event.app.invalidate() + return + if self._accept_completion_selection(): + event.app.invalidate() + return + self._accept_input(self.input.buffer) + event.app.invalidate() + + @kb.add("escape", "enter", filter=focus_input) + def _alt_enter(event) -> None: + self._insert_newline(event) + + @kb.add("escape", filter=Condition(lambda: self.state.busy)) + def _interrupt(event) -> None: + self._interrupt_active_turn() + event.app.invalidate() + + @kb.add("escape", filter=Condition(lambda: self._agent_picker.active)) + def _cancel_picker(event) -> None: + self._cancel_agent_picker() + event.app.invalidate() + + @kb.add("up", filter=focus_input) + def _up(event) -> None: + if self._agent_picker.active: + self._move_agent_picker(step=-1) + event.app.invalidate() + return + if self._move_completion_selection(step=-1): + event.app.invalidate() + return + self._browse_history(step=-1) + event.app.invalidate() + + @kb.add("down", filter=focus_input) + def _down(event) -> None: + if self._agent_picker.active: + self._move_agent_picker(step=1) + event.app.invalidate() + return + if self._move_completion_selection(step=1): + event.app.invalidate() + return + self._browse_history(step=1) + event.app.invalidate() + + @kb.add( + "left", + filter=focus_input & Condition(lambda: self.state.input_mode == "history_view"), + ) + def _left(event) -> None: + self._enter_edit_mode_from_history(direction="left") + event.app.invalidate() + + @kb.add( + "right", + filter=focus_input & Condition(lambda: self.state.input_mode == "history_view"), + ) + def _right(event) -> None: + self._enter_edit_mode_from_history(direction="right") + event.app.invalidate() + + @kb.add("f2") + def _toggle_tool(event) -> None: + self._toggle_selected_tool_panel() + event.app.invalidate() + + @kb.add("f3") + def _prev_tool(event) -> None: + self._move_tool_panel_selection(step=-1) + event.app.invalidate() + + @kb.add("f4") + def _next_tool(event) -> None: + self._move_tool_panel_selection(step=1) + event.app.invalidate() + + @kb.add("f6") + def _toggle_activity_mode(event) -> None: + self._activity_compact_mode = not self._activity_compact_mode + event.app.invalidate() + + @kb.add("tab") + def _tab(event) -> None: + buffer = self.input.buffer + if buffer.complete_state: + buffer.complete_next() + else: + buffer.start_completion(select_first=True) + event.app.invalidate() + + @kb.add("s-tab") + def _shift_tab(event) -> None: + buffer = self.input.buffer + if buffer.complete_state: + buffer.complete_previous() + else: + buffer.start_completion(select_first=False) + event.app.invalidate() + + self.app = Application( + layout=Layout(root_container, focused_element=self.input), + key_bindings=kb, + full_screen=False, + style=TUI_STYLE, + ) + self._refresh_hint() + + @staticmethod + def _build_command_alias_map(command_specs: list[dict[str, Any]]) -> dict[str, str]: + alias_map: dict[str, str] = {} + for spec in command_specs: + name = str(spec.get("name", "")).strip() + if not name: + continue + aliases_raw = spec.get("aliases", []) + if not isinstance(aliases_raw, list): + continue + for alias in aliases_raw: + alias_text = str(alias).strip() + if not alias_text: + continue + alias_map[alias_text] = name + return alias_map + + async def run(self) -> None: + for message in self._startup_messages: + self._append_line(f"Warning: {message}") + spinner_task = self.app.create_background_task(self._spin()) + try: + await self.app.run_async() + finally: + spinner_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await spinner_task + + async def prompt_permission( + self, tool: str, input_data: Any, message: str + ) -> PermissionDecision: + pending = self._start_activity(f"权限请求: {tool}") + preview = str(input_data)[:300] + selected = await radiolist_dialog( + title=f"Permission Required: {tool}", + text=f"{message}\n\nInput preview:\n{preview}\n\nChoose action:", + values=[ + ("allow_once", "Allow once"), + ("allow_session", "Always allow this tool in current session"), + ("deny_once", "Deny once"), + ("deny_session", "Always deny this tool in current session"), + ], + ).run_async() + + if selected == "allow_once": + self._finish_activity(pending.seq, status="done", suffix=" 已允许") + self._append_line(f"Permission: allow {tool}") + self.app.invalidate() + return PermissionDecision(behavior="allow") + + if selected == "allow_session": + add_session_rule(PermissionRule(tool=tool, behavior="allow", source="session")) + self._finish_activity(pending.seq, status="done", suffix=" 已允许(会话记忆)") + self._append_line(f"Permission: allow {tool} (session)") + self.app.invalidate() + return PermissionDecision(behavior="allow") + + if selected == "deny_session": + add_session_rule(PermissionRule(tool=tool, behavior="deny", source="session")) + self._finish_activity(pending.seq, status="error", suffix=" 已拒绝(会话记忆)") + self._append_line(f"Permission: deny {tool} (session)") + self.app.invalidate() + return PermissionDecision(behavior="deny", message=f"Denied by user for {tool}") + + self._finish_activity(pending.seq, status="error", suffix=" 已拒绝") + self._append_line(f"Permission: deny {tool}") + self.app.invalidate() + return PermissionDecision(behavior="deny", message=f"Denied by user for {tool}") + + def _start_activity(self, text: str) -> ActivityItem: + self._activity_seq += 1 + item = ActivityItem(seq=self._activity_seq, text=text) + self._activities.append(item) + if len(self._activities) > 24: + self._activities = self._activities[-24:] + return item + + def _finish_activity(self, seq: int | None, status: str = "done", suffix: str = "") -> None: + if seq is None: + return + item = next((activity for activity in self._activities if activity.seq == seq), None) + if item is None or item.status != "running": + return + item.status = status + item.ended_at = time.monotonic() + if suffix: + item.text = f"{item.text}{suffix}" + + def _finish_running_activities(self, status: str = "done") -> None: + for item in self._activities: + if item.status == "running": + item.status = status + item.ended_at = time.monotonic() + + def _reset_turn_feedback(self) -> None: + self._activities = [] + self._tool_activity_seq = {} + self._activity_seq = 0 + self._turn_started_at = time.monotonic() + self._model_wait_seq = None + self._stream_seq = None + self._assistant_started_in_turn = False + self._assistant_streaming_cursor = False + self._assistant_block_open = False + self._assistant_chunk_index = None + self._status_tool_text = None + self._status_tool_running = False + self._status_tool_error = False + + def _mark_turn_complete(self, failed: bool = False, message: str | None = None) -> None: + self._finish_activity(self._model_wait_seq, status="done") + self._finish_activity(self._stream_seq, status="done") + self._finish_running_activities(status="error" if failed else "done") + if self._turn_started_at is None: + return + elapsed = time.monotonic() - self._turn_started_at + if message: + item = self._start_activity(message) + item.status = "error" if failed else "done" + item.ended_at = time.monotonic() + return + if failed: + item = self._start_activity(f"回合结束(错误) {elapsed:.1f}s") + item.status = "error" + item.ended_at = time.monotonic() + return + finished = self._start_activity(f"回合完成 {elapsed:.1f}s") + finished.status = "done" + finished.ended_at = time.monotonic() + + def _running_activity_count(self) -> int: + return sum(1 for activity in self._activities if activity.status == "running") + + def _handle_ctrl_c(self) -> None: + if self.state.busy: + self._interrupt_active_turn(source="Ctrl+C") + return + self._running = False + self.app.exit() + + def _interrupt_active_turn(self, source: str = "Esc") -> None: + if not self.state.busy: + return + task = self._active_turn_task + if task is None or task.done(): + return + self._append_line(f"System: Interrupt requested ({source})") + self._start_activity(f"用户中断当前回合 ({source})") + task.cancel() + + def _shimmer_segments( + self, text: str, base_style: str, shine_style: str, width: int = 7 + ) -> list[tuple[str, str]]: + if not text: + return [] + cycle = len(text) + width + head = self._animation_tick % cycle + start = head - width + fragments: list[tuple[str, str]] = [] + for idx, char in enumerate(text): + style = shine_style if start <= idx <= head else base_style + if fragments and fragments[-1][0] == style: + prev_style, prev_text = fragments[-1] + fragments[-1] = (prev_style, prev_text + char) + else: + fragments.append((style, char)) + return fragments + + def _activity_fragments(self) -> FormattedText: + fragments: list[tuple[str, str]] = [] + + # AI 流式输出动态指示器(左侧) + if self._assistant_streaming_cursor and self._assistant_started_in_turn: + spinner_char = SPINNER_FRAMES[self.state.spinner_index] + fragments.append(("class:activity.ai_streaming", f" {spinner_char} AI 正在输出 ")) + + if not self._activities: + label = " Activity[min]: idle" if self._activity_compact_mode else " Activity: idle" + fragments.append(("class:activity.empty", label)) + return FormattedText(fragments) + + lines = self._activities[-1:] if self._activity_compact_mode else self._activities[-4:] + if self._activity_compact_mode: + fragments.append(("class:activity.empty", " | ")) + for item in lines: + if item.status == "running": + prefix = " … " + fragments.append(("class:activity.running", prefix)) + shimmer = self._shimmer_segments( + item.text, + base_style="class:activity.running", + shine_style="class:activity.shine", + ) + fragments.extend(shimmer) + fragments.append(("class:activity.running", "\n")) + continue + + prefix = " ✓ " if item.status == "done" else " ! " + style = "class:activity.done" if item.status == "done" else "class:activity.error" + text = item.text + if item.ended_at is not None: + duration = int((item.ended_at - item.started_at) * 1000) + text = f"{text} ({duration}ms)" + fragments.append((style, f"{prefix}{text}\n")) + + if self._activity_compact_mode and len(fragments) > 1: + # Keep compact mode to a single visible row. + text = "".join(chunk for _style, chunk in fragments).rstrip("\n") + return FormattedText([("class:activity.running", text)]) + + return FormattedText(fragments) + + def _status_fragments(self) -> FormattedText: + return self._status_bar_renderer.render( + StatusBarState( + columns=self.app.output.get_size().columns, + busy=self.state.busy, + status=self.state.status, + status_text=self.state.status_text, + input_mode=self.state.input_mode, + spinner_index=self.state.spinner_index, + running_activity_count=self._running_activity_count(), + tool_text=self._status_tool_text, + tool_running=self._status_tool_running, + tool_error=self._status_tool_error, + activity_compact_mode=self._activity_compact_mode, + queued_input_count=len(self._queued_inputs), + meta=self._get_status_meta(), + ) + ) + + def _tool_panel_fragments(self) -> FormattedText: + if not self._tool_panels: + return FormattedText( + [("class:toolpanel.empty", " Tool Panels: no tool calls in current turn")] + ) + + has_expanded = any(panel.expanded for panel in self._tool_panels) + if has_expanded: + panels = self._tool_panels + else: + selected = self._tool_panels[self._selected_tool_panel] + panels = [selected] + hidden_count = len(self._tool_panels) - len(panels) + + fragments: list[tuple[str, str]] = [] + for panel in panels: + idx = self._tool_panel_index_by_id.get(panel.tool_use_id, 0) + marker = ">" if idx == self._selected_tool_panel else " " + arrow = "▼" if panel.expanded else "▶" + status = panel.status + if panel.status == "done": + status = "error" if panel.is_error else "ok" + duration = f" {panel.duration_ms}ms" if panel.duration_ms is not None else "" + more = f" (+{hidden_count} more)" if hidden_count > 0 and not has_expanded else "" + summary = ( + f" {marker} {arrow} #{panel.seq} {panel.tool_name} " + f"{status}{duration} | {panel.input_preview}{more}\n" + ) + summary_class = "class:toolpanel.error" if panel.is_error else "class:toolpanel.summary" + fragments.append((summary_class, summary)) + + if panel.expanded: + detail = panel.result_detail or panel.result_preview or "(no output)" + fragments.append(("class:toolpanel.detail", f" {detail}\n")) + + return FormattedText(fragments) + + def _hint_fragments(self) -> FormattedText: + return FormattedText([("class:hint.text", f" {self._hint_text}")]) + + def _agent_picker_fragments(self) -> FormattedText: + return self._agent_picker.render() + + def _transcript_fragments(self) -> FormattedText: + if not self._transcript_chunks: + return FormattedText([]) + fragments: list[tuple[str, str]] = [ + (chunk.style, chunk.text) for chunk in self._transcript_chunks + ] + if self._should_show_idle_ai_prompt(): + if fragments and not fragments[-1][1].endswith("\n"): + fragments.append(("class:transcript.system", "\n")) + idle_icon = ">" if (self._animation_tick // 4) % 2 == 0 else " " + fragments.append(("class:transcript.ai_idle_prompt", f"{idle_icon} \n")) + return FormattedText(fragments) + + def _should_show_idle_ai_prompt(self) -> bool: + return ( + self.state.status == "idle" + and not self.state.busy + and not self._assistant_block_open + and not self._assistant_streaming_cursor + and not self._status_tool_running + and self._running_activity_count() == 0 + ) + + def _transcript_cursor_position(self) -> Point: + return Point(x=self._transcript_cursor_col, y=self._transcript_cursor_line) + + def _advance_transcript_cursor(self, text: str) -> None: + if not text: + return + line_breaks = text.count("\n") + if line_breaks == 0: + self._transcript_cursor_col += len(text) + return + self._transcript_cursor_line += line_breaks + self._transcript_cursor_col = len(text.rsplit("\n", 1)[-1]) + + def _append_transcript_text(self, text: str, style: str) -> None: + if not text: + return + if self._transcript_chunks and self._transcript_chunks[-1].style == style: + self._transcript_chunks[-1].text += text + else: + self._transcript_chunks.append(TranscriptChunk(style=style, text=text)) + self._advance_transcript_cursor(text) + + def _transcript_last_char(self) -> str | None: + for chunk in reversed(self._transcript_chunks): + if chunk.text: + return chunk.text[-1] + return None + + def _input_visible_lines(self) -> int: + if not hasattr(self, "input"): + return 1 + text = self.input.buffer.text + if not text: + return 1 + return text.count("\n") + 1 + + def _input_max_rows(self) -> int: + rows = 24 + if hasattr(self, "app"): + with contextlib.suppress(Exception): + rows = int(self.app.output.get_size().rows) + return max(4, rows // 3) + + def _input_height_dimension(self) -> D: + max_rows = self._input_max_rows() + preferred = min(max_rows, max(1, self._input_visible_lines())) + # Keep input height tied to content lines; don't consume extra free space. + return D(min=1, preferred=preferred, max=max_rows, weight=0) + + async def _spin(self) -> None: + while self._running: + if ( + self.state.busy + or self._running_activity_count() > 0 + or self._should_show_idle_ai_prompt() + ): + self.state.spinner_index = (self.state.spinner_index + 1) % len(SPINNER_FRAMES) + self._animation_tick = (self._animation_tick + 1) % 100_000 + self.app.invalidate() + await asyncio.sleep(0.12) + + def _accept_input(self, buffer) -> None: + raw = buffer.text + if not raw.strip(): + return + + text = raw.rstrip("\n") + if self.state.busy: + self._record_history(text) + self._history_index = None + self.state.on_edit_mode() + buffer.text = "" + self._queued_inputs.append(text) + queued = self._start_activity( + f"已排队后续输入(第 {len(self._queued_inputs)} 条)" + ) + queued.status = "done" + queued.ended_at = time.monotonic() + self._append_line( + f"System: queued follow-up ({len(self._queued_inputs)} pending)" + ) + self._refresh_hint() + self.app.invalidate() + return + + self._record_history(text) + self._history_index = None + self.state.on_edit_mode() + buffer.text = "" + self._refresh_hint() + self._active_turn_task = self.app.create_background_task(self._handle_submit(text)) + + def _record_history(self, text: str) -> None: + if not text.strip(): + return + if self._input_history and self._input_history[-1] == text: + return + self._input_history.append(text) + + def _set_input_text(self, text: str) -> None: + self.input.buffer.set_document( + Document(text=text, cursor_position=len(text)), + bypass_readonly=True, + ) + self._refresh_hint() + + def set_input_draft(self, text: str) -> None: + self.state.on_edit_mode() + self._set_input_text(text) + self.app.invalidate() + + def _browse_history(self, step: int) -> None: + if self.state.busy or not self._input_history: + return + + if self._history_index is None: + if step > 0: + return + self._history_index = len(self._input_history) - 1 + else: + next_index = self._history_index + step + if next_index < 0: + next_index = 0 + if next_index >= len(self._input_history): + self._history_index = None + self.state.on_edit_mode() + self._set_input_text("") + self.app.invalidate() + return + self._history_index = next_index + + self.state.on_history_view() + self._set_input_text(self._input_history[self._history_index]) + self.app.invalidate() + + def _enter_edit_mode_from_history(self, direction: str) -> None: + if self.state.input_mode != "history_view": + return + self.state.on_edit_mode() + if direction == "left": + self.input.buffer.cursor_left(count=1) + else: + self.input.buffer.cursor_right(count=1) + self._refresh_hint() + self.app.invalidate() + + def _insert_newline(self, event) -> None: + if self.state.input_mode == "history_view": + self.state.on_edit_mode() + event.current_buffer.insert_text("\n") + self._refresh_hint() + + def _trigger_dev_reload(self) -> bool: + if not self._dev_mode or self._agent_picker.active: + return False + if self.state.input_mode == "history_view": + self.state.on_edit_mode() + self._set_input_text("/reload") + self._accept_input(self.input.buffer) + return True + + def _move_completion_selection(self, step: int) -> bool: + buffer = self.input.buffer + state = buffer.complete_state + if state is None or not state.completions: + return False + if step < 0: + buffer.complete_previous(count=abs(step)) + elif step > 0: + buffer.complete_next(count=step) + return True + + def _accept_completion_selection(self) -> bool: + buffer = self.input.buffer + state = buffer.complete_state + if state is None or not state.completions: + return False + completion = state.current_completion or state.completions[0] + self._suspend_mention_autocomplete_once = True + buffer.apply_completion(completion) + self.state.on_edit_mode() + self._history_index = None + self._refresh_hint() + return True + + def _toggle_selected_tool_panel(self) -> None: + if not self._tool_panels: + return + panel = self._tool_panels[self._selected_tool_panel] + panel.expanded = not panel.expanded + + def _move_tool_panel_selection(self, step: int) -> None: + if not self._tool_panels: + return + count = len(self._tool_panels) + self._selected_tool_panel = (self._selected_tool_panel + step) % count + + def _move_agent_picker(self, step: int) -> None: + self._agent_picker.move(step) + + def _confirm_agent_picker(self) -> None: + self._agent_picker.confirm() + + def _cancel_agent_picker(self) -> None: + self._agent_picker.cancel() + + def _format_command_hint(self, name: str, spec: dict[str, Any]) -> str: + arg_hint = str(spec.get("argument_hint", "")).strip() + description = str(spec.get("description", "")).strip() + examples_raw = spec.get("examples", []) + examples = [str(item).strip() for item in examples_raw if str(item).strip()] + + head = f"/{name}" + if arg_hint: + head = f"{head} {arg_hint}" + if description: + head = f"{head} - {description}" + if examples: + head = f"{head} | e.g. {examples[0]}" + return head + + def _on_input_text_changed(self, _event: object) -> None: + self._refresh_hint() + if self._suspend_mention_autocomplete_once: + self._suspend_mention_autocomplete_once = False + return + self._maybe_autocomplete_mentions() + + def _maybe_autocomplete_mentions(self) -> None: + if self.state.input_mode != "compose": + return + + buffer = self.input.buffer + if buffer.completer is None or buffer.complete_state is not None: + return + + query = extract_mention_query(buffer.document.text_before_cursor) + if query is None: + return + + buffer.start_completion(select_first=False) + + def _refresh_hint(self) -> None: + text = self.input.buffer.text + default_hint = self._default_hint + if self.state.input_mode == "agent_picker": + default_hint = "选择 Agent:Up/Down 移动,Enter 确认,Esc 取消" + elif self.state.input_mode == "history_view": + default_hint = "历史浏览:Up/Down 浏览,Left/Right 进入编辑" + elif self.state.busy: + runtime_note = self._current_runtime_hint() + default_hint = f"运行中:{runtime_note} | Esc/Ctrl+C 中断当前回合 | Alt+Enter 换行" + elif not text: + default_hint = self._idle_short_hint + + if not text.startswith("/"): + self._hint_text = default_hint + return + + first_line = text.splitlines()[0] + raw = first_line[1:] + if not raw: + self._hint_text = "输入命令并按 Tab 补全,例如 /model 或 /agent" + return + + parts = raw.split(maxsplit=1) + command_token = parts[0].strip() + canonical_token = self._command_alias_to_name.get(command_token, command_token) + has_args = len(parts) > 1 or first_line.endswith(" ") + + if not has_args: + matches = [ + name for name in self._command_specs if name.startswith(canonical_token) + ] + alias_matches = [ + alias + for alias in self._command_alias_to_name + if alias.startswith(command_token) + ] + if len(matches) == 1 and not alias_matches: + command_name = matches[0] + self._hint_text = self._format_command_hint( + command_name, self._command_specs[command_name] + ) + return + if len(matches) == 1 and alias_matches: + command_name = matches[0] + self._hint_text = self._format_command_hint( + command_name, self._command_specs[command_name] + ) + return + if matches or alias_matches: + merged = {f"/{name}" for name in matches} + merged.update(f"/{alias}" for alias in alias_matches) + joined = ", ".join(sorted(merged)[:6]) + self._hint_text = f"命令候选:{joined}" + return + self._hint_text = "未知命令,输入 /help 查看命令列表" + return + + spec = self._command_specs.get(canonical_token) + if spec is None: + self._hint_text = "未知命令,输入 /help 查看命令列表" + return + self._hint_text = self._format_command_hint(canonical_token, spec) + + def _current_runtime_hint(self) -> str: + if self._tool_panels: + panel = self._tool_panels[self._selected_tool_panel] + if panel.status == "running": + return f"工具 {panel.tool_name} 运行中" + if panel.is_error: + return f"工具 {panel.tool_name} 失败" + return f"工具 {panel.tool_name} 完成" + + running_activity = next( + (item for item in reversed(self._activities) if item.status == "running"), + None, + ) + if running_activity is not None: + return running_activity.text + if self._activities: + return self._activities[-1].text + return "处理中" + + def _format_tool_input_preview(self, input_data: Any) -> str: + if isinstance(input_data, dict): + command = input_data.get("command") + if isinstance(command, str) and command.strip(): + return command.strip()[:80] + file_path = ( + input_data.get("file_path") + or input_data.get("path") + or input_data.get("filePath") + or input_data.get("filename") + ) + if isinstance(file_path, str) and file_path.strip(): + return file_path.strip()[:80] + return str(input_data)[:80] + + def _shorten_result(self, text: str, limit: int = 300) -> str: + cleaned = text.strip() + if len(cleaned) <= limit: + return cleaned + return cleaned[:limit] + " ..." + + def _register_tool_start(self, event: dict[str, Any]) -> ToolPanel: + tool_use_id = str(event.get("tool_use_id") or f"tool-{len(self._tool_panels)+1}") + panel = ToolPanel( + seq=len(self._tool_panels) + 1, + tool_use_id=tool_use_id, + tool_name=str(event.get("tool_name") or "Tool"), + input_preview=self._format_tool_input_preview(event.get("input", {})), + started_at=time.monotonic(), + expanded=False, + ) + self._tool_panel_index_by_id[tool_use_id] = len(self._tool_panels) + self._tool_panels.append(panel) + self._selected_tool_panel = len(self._tool_panels) - 1 + return panel + + def _register_tool_result(self, event: dict[str, Any]) -> ToolPanel: + tool_use_id = str(event.get("tool_use_id") or "") + panel_index = self._tool_panel_index_by_id.get(tool_use_id) + if panel_index is None: + panel = self._register_tool_start( + { + "tool_use_id": tool_use_id or f"tool-{len(self._tool_panels)+1}", + "tool_name": str(event.get("tool_name") or "Tool"), + "input": {}, + } + ) + panel_index = self._tool_panel_index_by_id[panel.tool_use_id] + panel = self._tool_panels[panel_index] + + panel.status = "done" + panel.is_error = bool(event.get("is_error", False)) + panel.duration_ms = int((time.monotonic() - panel.started_at) * 1000) + raw_result = str(event.get("result") or "") + panel.result_preview = self._shorten_result(raw_result, limit=120) + panel.result_detail = self._shorten_result(raw_result, limit=1200) + if panel.is_error: + panel.expanded = True + return panel + + async def _handle_submit(self, text: str) -> None: + current_text = text + try: + while True: + if current_text.strip() == "/agent": + await self._handle_agent_picker() + else: + self._close_assistant_block() + self._reset_turn_feedback() + self._tool_panels = [] + self._tool_panel_index_by_id = {} + self._selected_tool_panel = 0 + self.state.on_submit() + self._append_user_block(current_text) + self._start_activity("准备上下文") + self._refresh_hint() + self.app.invalidate() + + self._finish_running_activities(status="done") + if current_text.startswith("/"): + command_activity = self._start_activity("执行命令") + should_exit = await self._on_command( + current_text, self._append_line, self._on_event + ) + self._finish_activity(command_activity.seq, status="done") + if should_exit: + self._running = False + self.app.exit() + break + else: + self._model_wait_seq = self._start_activity("请求模型,等待首个响应").seq + await self._on_prompt(current_text, self._append_assistant, self._on_event) + self._close_assistant_block() + self.state.on_turn_complete() + self._mark_turn_complete() + self._refresh_hint() + self.app.invalidate() + + if not self._queued_inputs: + break + + queued_remaining = max(0, len(self._queued_inputs) - 1) + current_text = self._queued_inputs.pop(0) + self._append_line( + f"System: running queued follow-up ({queued_remaining} left after this)" + ) + self._refresh_hint() + self.app.invalidate() + except asyncio.CancelledError: + self._close_assistant_block(finalize=False) + self.state.on_turn_complete() + self._mark_turn_complete(failed=False, message="回合已中断") + if self._queued_inputs: + self._append_line("System: queued follow-ups cleared by interrupt") + self._queued_inputs.clear() + self._refresh_hint() + self.app.invalidate() + except Exception as exc: + self._close_assistant_block(finalize=False) + self._append_line(f"Error: {exc}") + self.state.on_error() + self._mark_turn_complete(failed=True) + if self._queued_inputs: + self._append_line("System: queued follow-ups cleared after error") + self._queued_inputs.clear() + self._refresh_hint() + self.app.invalidate() + finally: + self._active_turn_task = None + + async def _handle_agent_picker(self) -> None: + if self._list_agents is None or self._on_agent_select is None: + self._append_line("Error: /agent is not available in this mode.") + self.app.invalidate() + return + + profiles = self._list_agents() + if not profiles: + self._append_line("No agent profiles found in ~/.env/agent.json.") + self.app.invalidate() + return + + loop = asyncio.get_running_loop() + self._agent_picker.open(profiles, loop) + self.state.on_agent_picker_mode() + self._set_input_text("") + self._refresh_hint() + self.app.invalidate() + + selected = await self._agent_picker.wait() + + try: + if selected is None: + self._append_line("Agent selection cancelled.") + return + message = await self._on_agent_select(str(selected)) + self._append_line(message) + except Exception as exc: + self._append_line(f"Error: {exc}") + finally: + self._agent_picker.close() + self.state.on_edit_mode() + self._refresh_hint() + self.app.invalidate() + + def _on_event(self, event: dict[str, Any]) -> None: + event_type = str(event.get("type")) + if event_type == "tool_start": + self._close_assistant_block(finalize=False) + self.state.on_tool_start() + panel = self._register_tool_start(event) + activity = self._start_activity(f"调用工具 {panel.tool_name}") + self._tool_activity_seq[panel.tool_use_id] = activity.seq + self._status_tool_text = f"Tool#{panel.seq} {panel.tool_name} running" + self._status_tool_running = True + self._status_tool_error = False + elif event_type == "tool_result": + self._close_assistant_block(finalize=False) + panel = self._register_tool_result(event) + status = "failed" if panel.is_error else "done" + seq = self._tool_activity_seq.get(panel.tool_use_id) + self._finish_activity( + seq, + status="error" if panel.is_error else "done", + suffix=f" {status}", + ) + self._status_tool_text = f"Tool#{panel.seq} {panel.tool_name} {status}" + self._status_tool_running = False + self._status_tool_error = panel.is_error + elif event_type == "compact": + self._close_assistant_block(finalize=False) + self.state.on_compact() + compact_activity = self._start_activity("系统压缩上下文") + compact_activity.status = "done" + compact_activity.ended_at = time.monotonic() + self._append_line("System: Context compacted") + elif event_type == "error": + self._close_assistant_block(finalize=False) + self.state.on_error() + error_activity = self._start_activity(f"模型错误: {event.get('error')}") + error_activity.status = "error" + error_activity.ended_at = time.monotonic() + self._append_line(f"Error: {event.get('error')}") + elif event_type == "hook_debug": + text = str(event.get("text") or "").strip() + if text: + self._append_line(text) + elif event_type == "turn_complete": + self._close_assistant_block() + self.state.on_turn_complete() + self._refresh_hint() + self.app.invalidate() + + def _append_assistant(self, text: str) -> None: + self.state.on_assistant_text() + if not self._assistant_started_in_turn: + self._finish_activity(self._model_wait_seq, status="done") + self._stream_seq = self._start_activity("模型流式输出中").seq + self._assistant_started_in_turn = True + self._assistant_streaming_cursor = True + if not self._assistant_block_open: + if self._transcript_chunks and self._transcript_last_char() not in {None, "\n"}: + self._append_transcript_text("\n", style="class:transcript.system") + self._assistant_chunk_index = len(self._transcript_chunks) + self._transcript_chunks.append( + TranscriptChunk(style="class:assistant.pending", text="") + ) + self._assistant_block_open = True + + if self._assistant_chunk_index is not None: + chunk = self._transcript_chunks[self._assistant_chunk_index] + chunk.text += text + self._advance_transcript_cursor(text) + else: + self._append_transcript_text(text, style="class:assistant.pending") + self._refresh_hint() + self.app.invalidate() + + def _close_assistant_block(self, finalize: bool = True) -> None: + if not self._assistant_block_open: + return + if self._assistant_chunk_index is None: + self._assistant_block_open = False + return + chunk = self._transcript_chunks[self._assistant_chunk_index] + if chunk.text and not chunk.text.endswith("\n"): + chunk.text += "\n" + self._advance_transcript_cursor("\n") + if finalize: + chunk.style = "class:assistant" + self._assistant_block_open = False + self._assistant_chunk_index = None + if self._assistant_streaming_cursor: + self._assistant_streaming_cursor = False + + def _append_user_block(self, text: str) -> None: + self._close_assistant_block(finalize=True) + lines = text.splitlines() or [""] + if not lines: + return + if self._transcript_last_char() not in {None, "\n"}: + self._append_transcript_text("\n", style="class:transcript.system") + self._append_transcript_text("\n", style="class:transcript.system") + for line in lines: + self._append_transcript_text(f"> {line}\n", style="class:user") + + def _append_line(self, line: str, style: str = "class:transcript.system") -> None: + self._close_assistant_block(finalize=False) + self._append_transcript_text(line + "\n", style=style) diff --git a/eagent/tui/status_bar.py b/eagent/tui/status_bar.py new file mode 100644 index 0000000..1f62f52 --- /dev/null +++ b/eagent/tui/status_bar.py @@ -0,0 +1,138 @@ +"""Status bar rendering for the agent TUI.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +from prompt_toolkit.formatted_text import FormattedText + +SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏") + + +@dataclass(frozen=True) +class StatusMeta: + model: str + cwd: str = "" + git: str = "" + + +@dataclass(frozen=True) +class StatusBarState: + columns: int + busy: bool + status: str + status_text: str + input_mode: str + spinner_index: int + running_activity_count: int + tool_text: str | None + tool_running: bool + tool_error: bool + activity_compact_mode: bool + queued_input_count: int + meta: StatusMeta + + +class StatusBarRenderer: + def render(self, state: StatusBarState) -> FormattedText: + show_spinner = state.busy or state.running_activity_count > 0 or state.tool_running + spinner = SPINNER_FRAMES[state.spinner_index] if show_spinner else "•" + mode = "HISTORY" if state.input_mode == "history_view" else "EDIT" + activity_mode = "ACT:min" if state.activity_compact_mode else "ACT:full" + compact_idle = ( + not state.busy + and state.status == "idle" + and state.input_mode == "compose" + and state.queued_input_count == 0 + ) + fragments: list[tuple[str, str]] = [ + ("class:brand.rte", " RTE"), + ("class:brand.hyphen", "-"), + ("class:brand.ai", "AI "), + ("class:title", f"{spinner} {state.status_text} "), + ("class:dim", " | "), + ] + + if state.tool_text: + tool_icon = ( + SPINNER_FRAMES[state.spinner_index] + if state.tool_running + else "!" + if state.tool_error + else "✓" + ) + tool_style = "class:tool" if state.tool_running else "class:meta" + fragments.extend( + [ + (tool_style, f"{tool_icon} {state.tool_text} "), + ("class:dim", " | "), + ] + ) + + fragments.append(("class:meta", state.meta.model)) + used_width = sum(len(text) for _style, text in fragments) + + extras: list[tuple[str, str]] = [] + if state.columns >= 96 and state.meta.cwd: + path_budget = max(12, min(40, state.columns - used_width - 24)) + extras.extend( + [ + ("class:dim", " | "), + ("class:meta", self.shorten_path(state.meta.cwd, path_budget)), + ] + ) + if state.columns >= 118 and state.meta.git: + extra_width = sum(len(text) for _style, text in extras) + git_budget = max(10, min(28, state.columns - used_width - extra_width - 8)) + extras.extend( + [ + ("class:dim", " | "), + ("class:meta", self.clip_text(state.meta.git, git_budget)), + ] + ) + fragments.extend(extras) + + if not compact_idle and state.columns >= 72: + fragments.extend( + [ + ("class:dim", " | "), + ("class:meta", mode), + ("class:dim", " | "), + ("class:meta", activity_mode), + ] + ) + if state.queued_input_count: + fragments.extend( + [ + ("class:dim", " | "), + ("class:meta", f"Q:{state.queued_input_count}"), + ] + ) + return FormattedText(fragments) + + @staticmethod + def clip_text(text: str, max_width: int) -> str: + if max_width <= 0: + return "" + if len(text) <= max_width: + return text + if max_width <= 3: + return text[:max_width] + return text[: max_width - 1] + "…" + + @staticmethod + def shorten_path(path: str, max_width: int) -> str: + if max_width <= 0: + return "" + home = os.path.expanduser("~") + display = path + if display == home: + display = "~" + elif display.startswith(home + os.sep): + display = "~" + display[len(home) :] + if len(display) <= max_width: + return display + if max_width <= 4: + return display[-max_width:] + return "…" + display[-(max_width - 1) :] diff --git a/eagent/tui/styles.py b/eagent/tui/styles.py new file mode 100644 index 0000000..6200763 --- /dev/null +++ b/eagent/tui/styles.py @@ -0,0 +1,42 @@ +"""Style definitions for full-screen eagent TUI.""" + +from prompt_toolkit.styles import Style + +TUI_STYLE = Style.from_dict( + { + "status": "bg:#1f2a44 #f2f6ff", + "status.running": "bg:#1b4332 #e9ffef", + "status.error": "bg:#5c1d1d #ffecec", + "brand.rte": "#99f6e4", + "brand.hyphen": "#99f6e4", + "brand.ai": "bold #ffffff", + "logo.subtitle": "#8ea0bc", + "title": "bold #f3f7ff", + "meta": "#d7e3ff", + "input": "bg:#0f172a #f8fafc", + "hint": "bg:#111827 #cbd5e1", + "hint.text": "#cbd5e1", + "activity": "bg:#0f1724 #9fb0c8", + "activity.empty": "bg:#0f1724 #7d8ea8", + "activity.running": "bg:#0f1724 #99a9bf", + "activity.shine": "bg:#0f1724 #d9e3f2 bold", + "activity.done": "bg:#0f1724 #7f8fa5", + "activity.error": "bg:#201616 #ffcaca", + "activity.ai_streaming": "bg:#0f1724 #22d3ee bold", + "toolpanel": "bg:#101828 #d8e3ff", + "toolpanel.empty": "bg:#101828 #8ea0c2", + "toolpanel.summary": "bg:#101828 #d8e3ff", + "toolpanel.detail": "bg:#0d1428 #c6d4f5", + "toolpanel.error": "bg:#2b1a1a #ffcccc", + "transcript": "bg:#0b1020 #e6edf7", + "transcript.streaming_cursor": "#22d3ee nounderline blink", + "transcript.system": "#6b7a90", + "transcript.ai_idle_prompt": "#8ecbff blink", + "user": "#8ecbff", + "assistant.pending": "#74869f", + "assistant": "#f8fafc", + "tool": "#facc15", + "error": "bold #f87171", + "dim": "#94a3b8", + } +) diff --git a/eagent/utils/__init__.py b/eagent/utils/__init__.py new file mode 100644 index 0000000..1015baf --- /dev/null +++ b/eagent/utils/__init__.py @@ -0,0 +1,27 @@ +"""Utility helpers.""" + +from eagent.utils.completer import SlashCommandCompleter, build_completer +from eagent.utils.cost import create_cost_tracker, format_token_count, format_usd, summarize_cost +from eagent.utils.format import blue, bold, dim, gold, green, red, yellow +from eagent.utils.process import ProcessResult, run_process +from eagent.utils.streaming import collect_assistant_text, event_to_log_line + +__all__ = [ + "blue", + "green", + "yellow", + "red", + "dim", + "bold", + "gold", + "ProcessResult", + "run_process", + "create_cost_tracker", + "format_token_count", + "format_usd", + "summarize_cost", + "SlashCommandCompleter", + "build_completer", + "collect_assistant_text", + "event_to_log_line", +] diff --git a/eagent/utils/completer.py b/eagent/utils/completer.py new file mode 100644 index 0000000..f73907f --- /dev/null +++ b/eagent/utils/completer.py @@ -0,0 +1,317 @@ +"""Prompt toolkit completer for slash commands and @mentions.""" + +from __future__ import annotations + +import os +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from pathlib import Path + +from prompt_toolkit.completion import Completer, Completion +from prompt_toolkit.document import Document + +from eagent.commands.registry import get_command_info_list + +_IGNORED_DIR_NAMES = { + ".git", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".venv", + ".worktrees", + "__pycache__", + "node_modules", +} +_MAX_PATH_SUGGESTIONS = 4000 +_MAX_COMPLETION_ITEMS = 80 + + +@dataclass(frozen=True) +class ResumeSuggestion: + value: str + display: str + meta: str = "recent session" + + +def extract_mention_query(text: str) -> str | None: + at_index = text.rfind("@") + if at_index < 0: + return None + if at_index > 0: + prev_char = text[at_index - 1] + if prev_char.isascii() and (prev_char.isalnum() or prev_char == "_"): + # Avoid triggering file mentions when typing ascii email-like tokens. + return None + query = text[at_index + 1 :] + if any(ch.isspace() for ch in query): + return None + return query + + +class SlashCommandCompleter(Completer): + def __init__( + self, + model_suggestions: Callable[[], list[str]] | None = None, + resume_suggestions: Callable[[], list[str | ResumeSuggestion]] | None = None, + file_suggestions: Callable[[], list[str]] | None = None, + workspace_root: str | None = None, + command_specs: list[dict[str, object]] | None = None, + ) -> None: + self._commands = command_specs if command_specs is not None else get_command_info_list() + self._model_suggestions = model_suggestions or (lambda: []) + self._resume_suggestions = resume_suggestions or (lambda: []) + self._workspace_root = Path(workspace_root).resolve() if workspace_root else Path.cwd() + self._file_suggestions = file_suggestions or self._scan_workspace_paths + + def _model_values(self) -> list[str]: + names = [name.strip() for name in self._model_suggestions() if name.strip()] + deduped: list[str] = [] + seen: set[str] = set() + for name in names: + if name in seen: + continue + seen.add(name) + deduped.append(name) + return deduped + + def _resume_values(self) -> list[ResumeSuggestion]: + values: list[ResumeSuggestion] = [] + for raw in self._resume_suggestions(): + if isinstance(raw, ResumeSuggestion): + value = raw.value.strip() + display = raw.display.strip() or value + meta = raw.meta.strip() or "recent session" + else: + value = str(raw).strip() + display = value + meta = "recent session" + if value: + values.append(ResumeSuggestion(value=value, display=display, meta=meta)) + + deduped: list[ResumeSuggestion] = [] + seen: set[str] = set() + for suggestion in values: + if suggestion.value in seen: + continue + seen.add(suggestion.value) + deduped.append(suggestion) + return deduped + + def _scan_workspace_paths(self) -> list[str]: + root = self._workspace_root + if not root.exists() or not root.is_dir(): + return [] + + suggestions: list[str] = [] + for current_root, dir_names, file_names in os.walk(root, topdown=True, followlinks=False): + dir_names[:] = sorted(name for name in dir_names if name not in _IGNORED_DIR_NAMES) + relative_root = Path(current_root).relative_to(root) + + for dir_name in dir_names: + relative = (relative_root / dir_name).as_posix() + suggestions.append(f"{relative}/") + if len(suggestions) >= _MAX_PATH_SUGGESTIONS: + return suggestions[:_MAX_PATH_SUGGESTIONS] + + for file_name in sorted(file_names): + relative = (relative_root / file_name).as_posix() + suggestions.append(relative) + if len(suggestions) >= _MAX_PATH_SUGGESTIONS: + return suggestions[:_MAX_PATH_SUGGESTIONS] + + return suggestions + + def _file_values(self) -> list[str]: + values = [value.strip() for value in self._file_suggestions() if value.strip()] + deduped: list[str] = [] + seen: set[str] = set() + for value in values: + if value in seen: + continue + seen.add(value) + deduped.append(value) + return deduped + + @staticmethod + def _mention_query(text: str) -> str | None: + return extract_mention_query(text) + + @staticmethod + def _is_subsequence(query: str, target: str) -> bool: + if not query: + return True + target_index = 0 + for char in query: + found = False + while target_index < len(target): + if target[target_index] == char: + found = True + target_index += 1 + break + target_index += 1 + if not found: + return False + return True + + def _path_match_score(self, path_value: str, query: str) -> int | None: + normalized = path_value.rstrip("/").lower() + basename = normalized.rsplit("/", maxsplit=1)[-1] + lowered_query = query.lower() + + if not lowered_query: + return 0 + if normalized == lowered_query or basename == lowered_query: + return 0 + if basename.startswith(lowered_query): + return 1 + if normalized.startswith(lowered_query): + return 2 + if lowered_query in basename: + return 3 + if lowered_query in normalized: + return 4 + if self._is_subsequence(lowered_query, basename): + return 5 + if self._is_subsequence(lowered_query, normalized): + return 6 + return None + + def _mention_matches(self, query: str) -> list[str]: + ranked: list[tuple[int, int, str]] = [] + for value in self._file_values(): + score = self._path_match_score(value, query) + if score is None: + continue + ranked.append((score, len(value), value)) + ranked.sort(key=lambda item: (item[0], item[1], item[2])) + return [value for _score, _length, value in ranked[:_MAX_COMPLETION_ITEMS]] + + @staticmethod + def _token_matches(token: str, query: str) -> bool: + if not query: + return True + if token.startswith(query): + return True + return query in token + + @staticmethod + def _command_aliases(command: dict[str, object]) -> list[str]: + aliases_raw = command.get("aliases", []) + if not isinstance(aliases_raw, list): + return [] + return [str(alias) for alias in aliases_raw if isinstance(alias, str)] + + def _match_score(self, command: dict[str, object], query: str) -> int | None: + name = str(command.get("name", "")) + aliases = self._command_aliases(command) + + if not query: + return 0 + if name == query: + return 0 + if name.startswith(query): + return 1 + for alias in aliases: + if alias == query: + return 2 + for alias in aliases: + if alias.startswith(query): + return 3 + if query in name: + return 4 + if any(query in alias for alias in aliases): + return 5 + return None + + def get_completions( + self, document: Document, complete_event: object + ) -> Iterable[Completion]: + _ = complete_event + text = document.text_before_cursor + + mention_query = self._mention_query(text) + if mention_query is not None: + for value in self._mention_matches(mention_query): + meta = "directory" if value.endswith("/") else "file" + yield Completion( + value, + start_position=-len(mention_query), + display=f"@{value}", + display_meta=meta, + ) + return + + if not text.startswith("/"): + return + + raw = text[1:] + if " " not in raw: + current = raw + ranked: list[tuple[int, str, dict[str, object]]] = [] + for command in self._commands: + score = self._match_score(command, current) + if score is None: + continue + name = str(command["name"]) + ranked.append((score, name, command)) + + ranked.sort(key=lambda item: (item[0], item[1])) + for _, _name, command in ranked: + name = str(command["name"]) + aliases = self._command_aliases(command) + alias_suffix = f" ({', '.join(aliases)})" if aliases else "" + description = str(command.get("description", "")) + display = f"/{name}{alias_suffix} - {description}" + meta = str(command.get("argument_hint", "") or description) + yield Completion( + name, + start_position=-len(current), + display=display, + display_meta=meta, + ) + return + + command_name, rest = raw.split(" ", 1) + command_name = command_name.strip() + canonical_name = command_name + for command in self._commands: + aliases = self._command_aliases(command) + if canonical_name == str(command["name"]) or canonical_name in aliases: + canonical_name = str(command["name"]) + break + + arg_prefix = rest.lstrip() + if canonical_name == "model": + for name in self._model_values(): + if self._token_matches(name, arg_prefix): + yield Completion( + name, + start_position=-len(arg_prefix), + display=name, + display_meta="model/profile name", + ) + elif canonical_name == "resume": + for suggestion in self._resume_values(): + if self._token_matches(suggestion.value, arg_prefix): + yield Completion( + suggestion.value, + start_position=-len(arg_prefix), + display=suggestion.display, + display_meta=suggestion.meta, + ) + + +def build_completer( + model_suggestions: Callable[[], list[str]] | None = None, + resume_suggestions: Callable[[], list[str | ResumeSuggestion]] | None = None, + file_suggestions: Callable[[], list[str]] | None = None, + workspace_root: str | None = None, + command_specs: list[dict[str, object]] | None = None, +) -> SlashCommandCompleter: + return SlashCommandCompleter( + model_suggestions=model_suggestions, + resume_suggestions=resume_suggestions, + file_suggestions=file_suggestions, + workspace_root=workspace_root, + command_specs=command_specs, + ) diff --git a/eagent/utils/cost.py b/eagent/utils/cost.py new file mode 100644 index 0000000..8021f82 --- /dev/null +++ b/eagent/utils/cost.py @@ -0,0 +1,40 @@ +"""Cost tracker helpers.""" + +from __future__ import annotations + +from eagent.core.types import CostTracker + + +def format_token_count(value: int) -> str: + if value >= 1_000_000: + return f"{value / 1_000_000:.1f}M" + if value >= 1_000: + return f"{value / 1_000:.1f}K" + return str(value) + + +def format_usd(value: float) -> str: + if value < 0.001: + return f"${value:.5f}" + if value < 0.01: + return f"${value:.4f}" + if value < 1: + return f"${value:.3f}" + return f"${value:.2f}" + + +def create_cost_tracker() -> CostTracker: + return CostTracker() + + +def summarize_cost(tracker: CostTracker, model_config) -> str: + in_str = format_token_count(tracker.total_input_tokens) + out_str = format_token_count(tracker.total_output_tokens) + cost_str = format_usd(tracker.total_cost_usd(model_config)) + parts = [f"Turn {tracker.turns}", f"{in_str} in / {out_str} out"] + if tracker.total_cache_read_tokens > 0 or tracker.total_cache_creation_tokens > 0: + cache_read = format_token_count(tracker.total_cache_read_tokens) + cache_write = format_token_count(tracker.total_cache_creation_tokens) + parts.append(f"cache: {cache_read} read / {cache_write} write") + parts.append(cost_str) + return " | ".join(parts) diff --git a/eagent/utils/format.py b/eagent/utils/format.py new file mode 100644 index 0000000..e46d30d --- /dev/null +++ b/eagent/utils/format.py @@ -0,0 +1,35 @@ +"""Terminal formatting helpers.""" + +from __future__ import annotations + + +def _wrap(code: str, text: str) -> str: + return f"\x1b[{code}m{text}\x1b[0m" + + +def blue(text: str) -> str: + return _wrap("34", text) + + +def green(text: str) -> str: + return _wrap("32", text) + + +def yellow(text: str) -> str: + return _wrap("33", text) + + +def red(text: str) -> str: + return _wrap("31", text) + + +def dim(text: str) -> str: + return _wrap("2", text) + + +def bold(text: str) -> str: + return _wrap("1", text) + + +def gold(text: str) -> str: + return _wrap("33;1", text) diff --git a/eagent/utils/process.py b/eagent/utils/process.py new file mode 100644 index 0000000..f1bd6ed --- /dev/null +++ b/eagent/utils/process.py @@ -0,0 +1,54 @@ +"""Async process helpers.""" + +from __future__ import annotations + +import asyncio +import contextlib +from dataclasses import dataclass + + +@dataclass +class ProcessResult: + stdout: str + stderr: str + returncode: int + + +async def run_process( + command: list[str], + cwd: str | None = None, + timeout: float | None = None, +) -> ProcessResult: + process = await asyncio.create_subprocess_exec( + *command, + cwd=cwd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout_b, stderr_b = await asyncio.wait_for(process.communicate(), timeout=timeout) + except asyncio.CancelledError: + with contextlib.suppress(ProcessLookupError): + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=2) + except TimeoutError: + with contextlib.suppress(ProcessLookupError): + process.kill() + await process.wait() + raise + except TimeoutError: + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=2) + except TimeoutError: + process.kill() + await process.wait() + raise + + return ProcessResult( + stdout=stdout_b.decode("utf-8", errors="replace"), + stderr=stderr_b.decode("utf-8", errors="replace"), + returncode=process.returncode or 0, + ) diff --git a/eagent/utils/streaming.py b/eagent/utils/streaming.py new file mode 100644 index 0000000..81ed280 --- /dev/null +++ b/eagent/utils/streaming.py @@ -0,0 +1,40 @@ +"""Streaming event helpers.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator + +from eagent.core.types import StreamEvent + + +async def collect_assistant_text(events: AsyncGenerator[StreamEvent, None]) -> str: + chunks: list[str] = [] + async for event in events: + if event["type"] == "assistant_text": + chunks.append(event["text"]) + return "".join(chunks) + + +def event_to_log_line(event: StreamEvent) -> str: + event_type = event.get("type", "unknown") + if event_type == "assistant_text": + return event.get("text", "") + if event_type == "tool_start": + return f"[tool:start] {event.get('tool_name')}" + if event_type == "tool_result": + status = "error" if event.get("is_error") else "ok" + return f"[tool:{status}] {event.get('tool_name')}: {event.get('result', '')[:120]}" + if event_type == "compact": + return f"[compact] {event.get('old_tokens')} -> {event.get('new_tokens')}" + if event_type == "usage": + usage = event.get("usage") + if usage: + return ( + f"[usage] in={usage.input_tokens} out={usage.output_tokens} " + f"cache_r={usage.cache_read_tokens} cache_w={usage.cache_creation_tokens}" + ) + if event_type == "error": + return f"[error] {event.get('error')}" + if event_type == "hook_debug": + return f"[hook] {event.get('text', '')}" + return f"[{event_type}]" diff --git a/env.py b/env.py index d240476..9ced2a6 100644 --- a/env.py +++ b/env.py @@ -30,7 +30,6 @@ import argparse import logging import platform -import json script_path = os.path.abspath(__file__) mpath = os.path.dirname(script_path) @@ -78,6 +77,14 @@ def init_argparse(): cmd_menuconfig.add_parser(subs) cmd_package.add_parser(subs) cmd_sdk.add_parser(subs) + agent_parser = subs.add_parser( + 'agent', + help='Start EnvAgent AI assistant.', + description='Start EnvAgent AI assistant.', + add_help=False, + ) + agent_parser.add_argument('agent_args', nargs=argparse.REMAINDER) + agent_parser.set_defaults(func=agent) return parser @@ -221,6 +228,10 @@ def main(): export_environment_variable() init_logger(get_env_root()) + if len(sys.argv) > 1 and sys.argv[1] == 'agent': + agent() + return + parser = init_argparse() args = parser.parse_args() @@ -250,5 +261,17 @@ def system(): exec_arg('system') +def agent(args=None): + from eagent.cli import main as agent_main + from eagent.reload import ReloadArgs + + argv = getattr(args, 'agent_args', None) + if argv is None: + argv = sys.argv[2:] if len(sys.argv) > 1 and sys.argv[1] == 'agent' else sys.argv[1:] + ReloadArgs.remember([sys.argv[0], 'agent', *argv]) + sys.argv = [sys.argv[0], *argv] + agent_main(standalone_mode=True) + + if __name__ == '__main__': main() diff --git a/env.sh b/env.sh index 5c5280f..6870730 100644 --- a/env.sh +++ b/env.sh @@ -1 +1,19 @@ -export PATH=~/.env/tools/scripts:$PATH +#!/usr/bin/env bash + +VENV_ROOT="$HOME/.env/.venv" +ENV_SCRIPTS="$HOME/.env/tools/scripts" + +if [ ! -d "$VENV_ROOT" ]; then + echo "Create Python venv for RT-Thread..." + python3 -m venv "$VENV_ROOT" + # shellcheck source=/dev/null + source "$VENV_ROOT/bin/activate" + + python -m pip install --upgrade pip + pip install "$ENV_SCRIPTS" +else + # shellcheck source=/dev/null + source "$VENV_ROOT/bin/activate" +fi + +export PATH="$VENV_ROOT/bin:$ENV_SCRIPTS:$PATH" diff --git a/install_macos.sh b/install_macos.sh index 6098479..b30c27e 100755 --- a/install_macos.sh +++ b/install_macos.sh @@ -61,6 +61,13 @@ if ! [[ $($RTT_PYTHON -m pip list | grep requests) ]]; then $RTT_PYTHON -m pip install requests fi +for pypkg in psutil pyyaml anthropic click prompt_toolkit rich httpx pydantic; do + if ! [[ $($RTT_PYTHON -m pip list | grep $pypkg) ]]; then + echo "Installing $pypkg." + $RTT_PYTHON -m pip install $pypkg + fi +done + if ! [ -x "$(command -v arm-none-eabi-gcc)" ]; then echo "Installing GNU Arm Embedded Toolchain." brew install gnu-arm-embedded diff --git a/install_ubuntu.sh b/install_ubuntu.sh index 2d27e97..7388be4 100755 --- a/install_ubuntu.sh +++ b/install_ubuntu.sh @@ -2,7 +2,7 @@ sudo apt-get update sudo apt-get -qq install python3 python3-pip gcc git libncurses5-dev -y -pip install scons requests tqdm kconfiglib pyyaml +pip install scons requests psutil tqdm kconfiglib pyyaml anthropic click prompt_toolkit rich httpx pydantic url=https://raw.githubusercontent.com/RT-Thread/env/master/touch_env.sh if [ $1 ] && [ $1 = --gitee ]; then diff --git a/install_windows.ps1 b/install_windows.ps1 index 8ed3b77..8658c65 100644 --- a/install_windows.ps1 +++ b/install_windows.ps1 @@ -115,6 +115,16 @@ if (!$?) { echo "psutil module has installed. Jump this step." } +foreach ($pypkg in ("pyyaml", "anthropic", "click", "prompt_toolkit", "rich", "httpx", "pydantic")) { + cmd /c $RTT_PYTHON -m pip list -i $PIP_SOURCE --trusted-host $PIP_HOST | findstr $pypkg | Out-Null + if (!$?) { + echo "Installing $pypkg module." + cmd /c $RTT_PYTHON -m pip install $pypkg -i $PIP_SOURCE --trusted-host $PIP_HOST + } else { + echo "$pypkg module has installed. Jump this step." + } +} + $url = "https://raw.githubusercontent.com/RT-Thread/env/master/touch_env.ps1" if ($args[0] -eq "--gitee") { $url = "https://gitee.com/RT-Thread-Mirror/env/raw/master/touch_env.ps1" diff --git a/setup.py b/setup.py index cca760c..20755cb 100644 --- a/setup.py +++ b/setup.py @@ -15,24 +15,45 @@ 'Github repository': 'https:/github.com/rt-thread/env.git', 'User guide': 'https:/github.com/rt-thread/env.git', }, - python_requires='>=3.6', + python_requires='>=3.11', install_requires=[ 'SCons>=4.0.0', 'requests', 'psutil', 'tqdm', 'kconfiglib', + 'anthropic>=0.39.0', + 'click>=8.1.7', + 'prompt_toolkit>=3.0.47', + 'rich>=13.7.1', + 'httpx>=0.27.0', + 'pydantic>=2.8.2', + 'pyyaml>=6.0.2', 'windows-curses; platform_system=="Windows"', ], packages=[ 'env', 'env.cmds', 'env.cmds.cmd_package', + 'eagent', + 'eagent.commands', + 'eagent.context', + 'eagent.core', + 'eagent.files', + 'eagent.hooks', + 'eagent.mcp', + 'eagent.permissions', + 'eagent.prompt', + 'eagent.skills', + 'eagent.tools', + 'eagent.tui', + 'eagent.utils', ], package_dir={ 'env': '.', 'env.cmds': 'cmds', 'env.cmds.cmd_package': 'cmds/cmd_package', + 'eagent': 'eagent', }, package_data={'': ['*.*']}, exclude_package_data={'': ['MANIFEST.in']}, @@ -44,6 +65,8 @@ 'pkgs=env.env:pkgs', 'sdk=env.env:sdk', 'system=env.env:system', + 'agent=eagent.cli:main', + 'eagent=eagent.cli:main', ] }, ) diff --git a/touch_env.sh b/touch_env.sh index c35dc4d..7a9dee9 100755 --- a/touch_env.sh +++ b/touch_env.sh @@ -29,5 +29,6 @@ if ! [ -d $env_dir ]; then echo 'source "$PKGS_DIR/packages/Kconfig"' >$env_dir/packages/Kconfig git clone $SDK_URL $env_dir/packages/sdk --depth=1 git clone $ENV_URL $env_dir/tools/scripts --depth=1 - echo -e 'export PATH=`python3 -m site --user-base`/bin:$HOME/.env/tools/scripts:$PATH\nexport RTT_EXEC_PATH=/usr/bin' >$env_dir/env.sh + cp $env_dir/tools/scripts/env.sh $env_dir/env.sh + echo 'export RTT_EXEC_PATH=/usr/bin' >>$env_dir/env.sh fi