Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 39 additions & 114 deletions packages/kcastle/src/kcastle/castle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,22 @@
import asyncio
import signal
from pathlib import Path
from typing import Any

from kagent import Agent, Trace
from kai import ProviderBase, Tool
from kai import Tool

from kcastle.channels import Channel
from kcastle.channels.cli import CLIChannel
from kcastle.channels.telegram import TelegramChannel
from kcastle.config import CastleConfig, load_config
from kcastle.log import logger
from kcastle.providers import create_provider
from kcastle.providers import ModelManager, create_provider
from kcastle.session.manager import SessionManager
from kcastle.skills.manager import SkillManager, find_project_root
from kcastle.skills.skill import Skill
from kcastle.skills.skill import render_compact_skills
from kcastle.tools import create_builtin_tools


def _create_provider(config: CastleConfig) -> ProviderBase:
"""Create a kai Provider from the active provider config."""
provider_config = config.active_provider_config()
return create_provider(provider_config)


def _build_system_prompt(config: CastleConfig, skill_prompts: str = "") -> str:
"""Assemble the system prompt from composable blocks."""
from kcastle.prompts import (
Expand Down Expand Up @@ -70,20 +63,17 @@ def __init__(
session_manager: SessionManager,
skill_manager: SkillManager,
channels: list[Channel],
provider: ProviderBase,
model_manager: ModelManager,
system_prompt: str,
skill_tools: list[Tool],
) -> None:
self._config = config
self._session_manager = session_manager
self._skill_manager = skill_manager
self._channels: list[Channel] = channels
self._provider = provider
self._model_manager = model_manager
self._system_prompt = system_prompt
self._skill_tools = skill_tools
self._active_provider_name = config.default_provider
self._active_model = config.default_model
self._session_models: dict[str, tuple[str, str]] = {}

@property
def config(self) -> CastleConfig:
Expand All @@ -97,89 +87,31 @@ def session_manager(self) -> SessionManager:
def skill_manager(self) -> SkillManager:
return self._skill_manager

@property
def model_manager(self) -> ModelManager:
return self._model_manager

@property
def active_provider_name(self) -> str:
return self._active_provider_name
return self._model_manager.active_provider_name

@property
def active_model(self) -> str:
return self._active_model
return self._model_manager.active_model

def get_active_model(self, session_id: str | None = None) -> tuple[str, str]:
"""Return active ``(provider_name, model_id)``.

If ``session_id`` has an override, returns that override; otherwise
returns the global default runtime model.
Delegates to :class:`~kcastle.providers.ModelManager`.
"""
if session_id is not None:
if session_id in self._session_models:
return self._session_models[session_id]

loaded = self._session_manager.get(session_id)
if loaded is not None and loaded.model_override is not None:
override = loaded.model_override
self._session_models[session_id] = override
try:
loaded.agent.replace_llm(self._build_provider(*override))
except (ValueError, RuntimeError):
logger.warning(
"Failed to restore session %s model override %s / %s",
session_id,
override[0],
override[1],
)
return override

return (self._active_provider_name, self._active_model)

def prepare_user_input(self, user_input: str) -> str:
"""Augment user input with explicitly hinted skill instructions.

Bub-style progressive disclosure:
- compact skill metadata is always present in system prompt
- full skill body is injected only when user references ``$skill-name``
"""
hints = Skill.extract_hints(user_input)
if not hints:
return user_input

expanded: list[Any] = []
for hint in hints:
skill = self._skill_manager.get_skill(hint)
if skill is None:
continue
expanded.append(skill)

expansion_block = Skill.render_expanded(expanded)
if not expansion_block:
return user_input
return f"{user_input}\n\n{expansion_block}"

def _build_provider(self, provider_name: str, model_id: str) -> ProviderBase:
"""Validate and build a provider instance for ``provider_name/model_id``."""
provider_config = self._config.provider_config(provider_name, model_id)
return create_provider(provider_config)

def _apply_provider_to_session(self, session_id: str, provider: ProviderBase) -> None:
"""Hot-swap provider for one loaded session."""
session = self._session_manager.get(session_id)
if session is None:
raise KeyError(f"Session {session_id!r} is not loaded")
session.agent.replace_llm(provider)
return self._model_manager.get_active_model(session_id)

def available_models(self) -> list[tuple[str, str]]:
"""Return ``(provider_name, model_id)`` pairs for all active models.

Only includes providers whose API key is non-empty after env
expansion.
Delegates to :class:`~kcastle.providers.ModelManager`.
"""
result: list[tuple[str, str]] = []
for pname, pcfg in self._config.providers.items():
if not pcfg.api_key:
continue
for m in pcfg.active_models():
result.append((pname, m.id))
return result
return self._model_manager.available_models()

def switch_model(
self,
Expand All @@ -188,33 +120,18 @@ def switch_model(
*,
session_id: str,
) -> None:
"""Switch the active provider and model at runtime.
"""Switch the active provider and model for a single loaded session.

Only updates the specified loaded session.
Delegates to :class:`~kcastle.providers.ModelManager`.
"""
current_provider, current_model = self.get_active_model(session_id)
logger.info(
"Switching session %s model: %s / %s -> %s / %s",
session_id,
current_provider,
current_model,
provider_name,
model_id,
)
self._model_manager.switch_model(provider_name, model_id, session_id=session_id)

provider = self._build_provider(provider_name, model_id)
self._apply_provider_to_session(session_id, provider)
session = self._session_manager.get(session_id)
if session is None:
raise KeyError(f"Session {session_id!r} is not loaded")
session.set_model_override(provider_name, model_id)
self._session_models[session_id] = (provider_name, model_id)
logger.info(
"Switched session %s model to %s / %s",
session_id,
provider_name,
model_id,
)
def prepare_user_input(self, user_input: str) -> str:
"""Augment user input with explicitly hinted skill instructions.

Delegates to :meth:`~kcastle.skills.SkillManager.expand_hints`.
"""
return self._skill_manager.expand_hints(user_input)

@classmethod
def create(
Expand Down Expand Up @@ -245,14 +162,15 @@ def create(
)
skill_manager.discover()

provider = _create_provider(config)
provider_config = config.active_provider_config()
provider = create_provider(provider_config)

all_skills = skill_manager.all_skills()
skill_tools = create_builtin_tools(
workspace=Path.cwd(),
skill_manager=skill_manager,
)
skill_prompts = Skill.render_compact(all_skills)
skill_prompts = render_compact_skills(all_skills)

system_prompt = _build_system_prompt(config, skill_prompts)

Expand Down Expand Up @@ -282,15 +200,22 @@ def agent_factory(trace: Trace) -> Agent:
max_turns=config.max_turns,
)

session_manager = SessionManager(
sessions_dir=config.sessions_dir,
agent_factory=agent_factory,
)

model_manager = ModelManager(
config=config,
session_manager=session_manager,
)

return cls(
config=config,
session_manager=SessionManager(
sessions_dir=config.sessions_dir,
agent_factory=agent_factory,
),
session_manager=session_manager,
skill_manager=skill_manager,
channels=channels,
provider=provider,
model_manager=model_manager,
system_prompt=system_prompt,
skill_tools=skill_tools,
)
Expand Down
2 changes: 2 additions & 0 deletions packages/kcastle/src/kcastle/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
parse_models,
parse_providers,
)
from kcastle.providers.model_manager import ModelManager

__all__ = [
"ModelConfig",
"ModelManager",
"ProviderConfig",
"ProviderEntry",
"ProviderRegistry",
Expand Down
Loading