Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 192 additions & 26 deletions runtime_manager/worker_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import traceback
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

# Module-wide lock. NOTE(review): presumably serializes writes of emitted
# events to the worker's output stream -- confirm at its use sites.
_OUTPUT_LOCK = threading.Lock()
# Single-slot holder for the active agent; the "agent" entry starts as None
# and is overwritten with the AIAgent instance once a run constructs one.
_AGENT_HOLDER: dict[str, Any] = {"agent": None}
Expand All @@ -32,8 +33,18 @@
_SENSITIVE_FLAG_VALUE_PATTERN = re.compile(
r"(?i)(--(?:token|password|secret|key))\s+(?:\"[^\"]*\"|'[^']*'|\S+)"
)
_SENSITIVE_STATUS_VALUE_PATTERN = re.compile(
r"(?i)\b(api[_-]?key|authorization|token|password|secret)(\s*[:=]\s*)([^\s,;.]+)"
)
_BEARER_STATUS_VALUE_PATTERN = re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._~+/=-]+")
_INTERNAL_USER_PATH_PATTERN = re.compile(r"/opt/data/users/[^\s'\"]+")
_WHITESPACE_PATTERN = re.compile(r"\s+")
_COMPRESSION_SUMMARY_ERROR_PATTERN = re.compile(
r"(?is)compression summary failed:\s*(?P<error>.*?)(?:\.\s*inserted\b|$)"
)
_COMPRESSION_ABORT_ERROR_PATTERN = re.compile(
r"(?is)compression aborted:\s*(?P<error>.*?)(?:\.\s*no messages\b|$)"
)
_TOOL_OUTPUT_PREVIEW_LIMIT = 1200
_OUTPUT_TRUNCATION_MARKERS = (
"output too long",
Expand Down Expand Up @@ -137,15 +148,27 @@ def on_tool_progress(
)

def on_status(kind: str, message: str) -> None:
    """Forward an agent status message and any compression event derived from it.

    Always emits a ``status.message`` event. When
    ``_compression_event_from_status`` recognises the message, a second event
    is emitted that carries the same ``run_id`` and the same timestamp.

    Args:
        kind: Status category reported by the agent.
        message: Status text reported by the agent.
    """
    # One clock reading is shared by both events so they can be correlated.
    timestamp = time.time()
    emit(
        {
            "event": "status.message",
            "run_id": run_id,
            "timestamp": timestamp,
            "kind": kind,
            "message": message,
        }
    )
    # NOTE(review): provider, model and base_url are read from the enclosing
    # scope, which is not visible here -- confirm they are still bound there
    # by the time this callback fires.
    compression_event = _compression_event_from_status(
        kind,
        message,
        provider=provider,
        model=model,
        base_url=base_url,
    )
    if compression_event:
        compression_event["run_id"] = run_id
        compression_event["timestamp"] = timestamp
        emit(compression_event)

def approval_notify(data: dict[str, Any]) -> None:
event = dict(data or {})
Expand Down Expand Up @@ -348,27 +371,7 @@ def on_thinking(message: str | None) -> None:
session_key=approval_session_key,
)
register_gateway_notify(approval_session_key, approval_notify)
llm_config = request.get("llm_config")
if not isinstance(llm_config, dict):
llm_config = {}
model = (
request.get("model")
or llm_config.get("model")
or llm_config.get("name")
or llm_config.get("default")
or ""
)
api_key = _first_present(request.get("api_key"), llm_config.get("api_key"), llm_config.get("apiKey"))
base_url = _first_present(
request.get("base_url"),
request.get("baseURL"),
llm_config.get("base_url"),
llm_config.get("baseURL"),
)
provider = _normalize_agent_provider(
_first_present(request.get("provider"), llm_config.get("provider")),
base_url=base_url,
)
runtime_llm_config = _resolve_runtime_llm_config(request)

system_prompt = _compose_effective_system_prompt(
request,
Expand All @@ -377,10 +380,11 @@ def on_thinking(message: str | None) -> None:
)

agent = AIAgent(
model=str(model or ""),
provider=provider,
api_key=api_key,
base_url=base_url,
model=runtime_llm_config["model"],
provider=runtime_llm_config["provider"],
api_key=runtime_llm_config["api_key"],
base_url=runtime_llm_config["base_url"],
max_tokens=runtime_llm_config.get("max_tokens"),
session_id=session_id,
session_db=SessionDB(),
quiet_mode=True,
Expand All @@ -404,6 +408,7 @@ def on_thinking(message: str | None) -> None:
ephemeral_system_prompt=system_prompt,
max_iterations=int(request.get("max_iterations") or 90),
)
_apply_runtime_llm_config_to_agent(agent, runtime_llm_config)
_AGENT_HOLDER["agent"] = agent

emit({"event": "run.running", "run_id": run_id, "timestamp": time.time()})
Expand Down Expand Up @@ -483,6 +488,93 @@ def _first_present(*values: Any) -> Any:
return None


def _resolve_runtime_llm_config(request: dict[str, Any]) -> dict[str, Any]:
    """Gather the LLM settings for a run from the request payload.

    Top-level request fields are consulted before the nested ``llm_config``
    mapping, which is ignored unless it is a dict.

    Args:
        request: Run request payload.

    Returns:
        A dict that always has ``model``, ``provider``, ``api_key`` and
        ``base_url``; ``max_tokens`` and ``context_length`` are added only
        when the coerced value is greater than zero.
    """
    nested = request.get("llm_config")
    if not isinstance(nested, dict):
        nested = {}

    def _lookup(*keys: str) -> Any:
        # Request-level spellings first, then the same spellings in llm_config.
        return _first_present(
            *(request.get(key) for key in keys),
            *(nested.get(key) for key in keys),
        )

    model_name = (
        request.get("model")
        or nested.get("model")
        or nested.get("name")
        or nested.get("default")
        or ""
    )
    # The camelCase "apiKey" spelling is only honoured inside llm_config.
    secret = _first_present(
        request.get("api_key"),
        nested.get("api_key"),
        nested.get("apiKey"),
    )
    endpoint = _lookup("base_url", "baseURL")
    backend = _normalize_agent_provider(_lookup("provider"), base_url=endpoint)
    if endpoint and secret:
        # An explicit endpoint together with a key overrides the provider.
        backend = "custom"

    settings: dict[str, Any] = {
        "model": str(model_name or ""),
        "provider": backend,
        "api_key": secret,
        "base_url": endpoint,
    }

    optional_limits = (
        ("max_tokens", ("max_tokens", "maxTokens")),
        (
            "context_length",
            ("context_length", "contextLength", "context_window", "contextWindow"),
        ),
    )
    for field, spellings in optional_limits:
        amount = _coerce_positive_int(_lookup(*spellings))
        if amount > 0:
            settings[field] = amount

    return settings


def _apply_runtime_llm_config_to_agent(agent: Any, runtime_llm_config: dict[str, Any]) -> None:
context_length = runtime_llm_config.get("context_length")
if context_length is None:
return

setattr(agent, "_config_context_length", context_length)

session_model_config = getattr(agent, "_session_init_model_config", None)
if isinstance(session_model_config, dict):
session_model_config["context_length"] = context_length

compressor = getattr(agent, "context_compressor", None)
update_model = getattr(compressor, "update_model", None)
if not callable(update_model):
return

update_model(
model=str(getattr(agent, "model", None) or runtime_llm_config.get("model") or ""),
context_length=context_length,
base_url=str(getattr(agent, "base_url", None) or runtime_llm_config.get("base_url") or ""),
api_key=getattr(agent, "api_key", None) or runtime_llm_config.get("api_key") or "",
provider=str(getattr(agent, "provider", None) or runtime_llm_config.get("provider") or ""),
api_mode=str(getattr(agent, "api_mode", None) or runtime_llm_config.get("api_mode") or ""),
)


def _normalize_agent_provider(provider: Any, *, base_url: Any = None) -> str | None:
if provider is None:
return None
Expand Down Expand Up @@ -512,6 +604,80 @@ def _normalize_agent_provider(provider: Any, *, base_url: Any = None) -> str | N
return value


def _base_url_host(value: Any) -> str:
text = str(value or "").strip()
if not text:
return ""
try:
parsed = urlparse(text)
except Exception:
return ""
return parsed.netloc or parsed.path.split("/", 1)[0]


def _safe_status_event_text(value: Any, *, limit: int = 512) -> str:
    """Redact secrets and internal paths from status text, then truncate it.

    Args:
        value: Text to sanitise; falsy values become ``""``.
        limit: Maximum number of characters kept before ``"..."`` is appended.

    Returns:
        The sanitised text, cut to ``limit`` characters plus ``"..."`` when it
        was longer.
    """
    text = str(value or "")
    # Bearer tokens go first: the key/value pass below would otherwise consume
    # only the word "Bearer" in "authorization: Bearer <token>" and leave the
    # token itself in the output.
    text = _BEARER_STATUS_VALUE_PATTERN.sub("Bearer <redacted>", text)
    text = _SENSITIVE_STATUS_VALUE_PATTERN.sub(r"\1\2<redacted>", text)
    text = _INTERNAL_USER_PATH_PATTERN.sub("/opt/data/users/<redacted>", text)
    # Truncate after redaction so a secret cannot be cut in half and slip
    # past the patterns.
    if len(text) > limit:
        return text[:limit] + "..."
    return text


def _compression_event_from_status(
    kind: Any,
    message: Any,
    *,
    provider: Any = None,
    model: Any = None,
    base_url: Any = None,
) -> dict[str, Any] | None:
    """Translate a context-compression status message into a structured event.

    The message is matched case-insensitively against four known phrases. A
    summary failure, a missing auxiliary provider and a repeated-compression
    notice produce a ``context.compression.warning`` event; an aborted
    compression produces ``context.compression.failed``.

    Args:
        kind: Status kind reported alongside the message.
        message: Raw status text.
        provider: Provider name to record on the event.
        model: Model name to record on the event.
        base_url: Base URL; only its host part is recorded.

    Returns:
        The event dict (without ``run_id`` or ``timestamp``), or ``None`` when
        the message matches none of the known phrases.
    """
    raw = str(message or "")
    lowered = raw.lower()

    def _captured_error(pattern: Any) -> str:
        # Pull the "error" group out of the original-case text, if present.
        found = pattern.search(raw)
        return found.group("error").strip() if found else ""

    warning = "context.compression.warning"
    # Each outcome is (event name, reason, fallback, abort, error detail).
    if "compression summary failed" in lowered:
        outcome = (
            warning,
            "summary_failed",
            True,
            False,
            _captured_error(_COMPRESSION_SUMMARY_ERROR_PATTERN),
        )
    elif "compression aborted" in lowered:
        outcome = (
            "context.compression.failed",
            "compression_aborted",
            False,
            True,
            _captured_error(_COMPRESSION_ABORT_ERROR_PATTERN),
        )
    elif "no auxiliary llm provider configured" in lowered:
        outcome = (
            warning,
            "no_auxiliary_provider",
            True,
            False,
            "no auxiliary LLM provider configured",
        )
    elif "session compressed" in lowered and "accuracy may degrade" in lowered:
        outcome = (warning, "repeated_compression", False, False, "")
    else:
        return None

    name, why, is_fallback, is_abort, detail = outcome
    payload: dict[str, Any] = {
        "event": name,
        "reason": why,
        "sourceEventKind": str(kind or ""),
        "message": _safe_status_event_text(raw),
        "fallback": is_fallback,
        "abort": is_abort,
        "provider": str(provider or ""),
        "providerSource": "configured" if provider else "auto",
        "model": str(model or ""),
        "baseURLHost": _base_url_host(base_url),
    }
    if detail:
        payload["summaryError"] = _safe_status_event_text(detail)
    return payload


def _json_preview(value: Any, *, limit: int = 500) -> str:
try:
if isinstance(value, str):
Expand Down
Loading
Loading