281 changes: 253 additions & 28 deletions AFL/double_agent/AgentDriver.py
@@ -3,20 +3,153 @@
import json
import inspect
import importlib
import pkgutil
from typing import Optional, Dict, Any, List, get_type_hints, Union
import copy
import hashlib
import logging
import os
import tempfile
import threading
import time
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple, get_type_hints, Union

import xarray as xr

from AFL.automation.APIServer.Driver import Driver # type: ignore
from AFL.automation.shared.utilities import mpl_plot_to_bytes,xarray_to_bytes
try:
from AFL.automation.APIServer.Driver import Driver # type: ignore
from AFL.automation.shared.utilities import mpl_plot_to_bytes, xarray_to_bytes
except ModuleNotFoundError as exc:
# Allow unit tests to import this module in environments where AFL-automation
# is not installed. Runtime server behavior still requires AFL-automation.
if exc.name and exc.name.startswith("AFL.automation"):
class Driver: # type: ignore[override]
@staticmethod
def unqueued(*args, **kwargs):
def decorator(func):
return func
return decorator

@staticmethod
def queued(*args, **kwargs):
def decorator(func):
return func
return decorator

def __init__(self, *args, **kwargs):
pass

def gather_defaults(self):
return getattr(self, "defaults", {})

def mpl_plot_to_bytes(*args, **kwargs): # type: ignore[no-redef]
raise RuntimeError("mpl_plot_to_bytes requires AFL-automation to be installed.")

def xarray_to_bytes(*args, **kwargs): # type: ignore[no-redef]
raise RuntimeError("xarray_to_bytes requires AFL-automation to be installed.")
else:
raise
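The stub keeps `@Driver.unqueued()` and `@Driver.queued()` importable as no-op decorators, so the class body below still parses in test environments; the replacement `mpl_plot_to_bytes` / `xarray_to_bytes` fail loudly at call time rather than degrading silently.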
from AFL.double_agent.Pipeline import Pipeline
from AFL.double_agent.PipelineOp import PipelineOp
from AFL.double_agent.util import listify

from importlib.resources import files
from jinja2 import Template

LOGGER = logging.getLogger(__name__)
_DISCOVERY_LOCK = threading.Lock()
_PIPELINE_OPS_MEM_CACHE: Optional[Dict[str, Any]] = None


def _cache_path() -> pathlib.Path:
env_path = os.environ.get("AFL_PIPELINE_OPS_CACHE_PATH")
if env_path:
return pathlib.Path(env_path).expanduser()
return pathlib.Path.home() / ".cache" / "afl-double-agent" / "pipeline_ops_manifest.json"
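A minimal sketch of how the override resolves, assuming a writable path (the `/tmp` location is illustrative):

import os
import pathlib

# Point discovery at a custom manifest location (hypothetical path).
os.environ["AFL_PIPELINE_OPS_CACHE_PATH"] = "/tmp/afl-demo/manifest.json"
assert _cache_path() == pathlib.Path("/tmp/afl-demo/manifest.json")

# With the variable unset, the default under ~/.cache is used.
del os.environ["AFL_PIPELINE_OPS_CACHE_PATH"]
assert _cache_path() == pathlib.Path.home() / ".cache" / "afl-double-agent" / "pipeline_ops_manifest.json"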


def _candidate_module_files() -> List[pathlib.Path]:
module_dir = pathlib.Path(__file__).parent
excluded = {
"__init__.py",
"_version.py",
"AgentDriver.py",
"util.py",
}

module_files: List[pathlib.Path] = []
for path in sorted(module_dir.glob("*.py")):
if path.name in excluded or path.name.startswith("_"):
continue

# Cheap pre-filter to avoid importing modules that cannot define PipelineOps.
try:
if "PipelineOp" not in path.read_text(encoding="utf-8"):
continue
except Exception:
# If the pre-filter fails, keep the module candidate for safety.
pass
module_files.append(path)
return module_files


def _module_signature(module_files: List[pathlib.Path]) -> str:
hasher = hashlib.sha256()
for path in module_files:
stat = path.stat()
hasher.update(str(path).encode("utf-8"))
hasher.update(str(stat.st_mtime_ns).encode("utf-8"))
hasher.update(str(stat.st_size).encode("utf-8"))
return hasher.hexdigest()
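Because the digest folds in each candidate's path, mtime, and size, merely re-saving a module invalidates both caches. A sketch, assuming `Pipeline.py` passes the pre-filter:

import pathlib
from AFL import double_agent

sig_before = _module_signature(_candidate_module_files())

# touch() bumps the mtime, which changes the digest on the next call.
pathlib.Path(double_agent.__file__).parent.joinpath("Pipeline.py").touch()

sig_after = _module_signature(_candidate_module_files())
assert sig_before != sig_after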


def _parse_strict_flag(value: Any) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
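The parser is deliberately permissive about query-string booleans; anything outside the truthy set falls through to False:

assert _parse_strict_flag(True) is True
assert _parse_strict_flag("TRUE") is True    # case-insensitive
assert _parse_strict_flag(" on ") is True    # surrounding whitespace is stripped
assert _parse_strict_flag("1") is True
assert _parse_strict_flag(None) is False
assert _parse_strict_flag("0") is False
assert _parse_strict_flag("unexpected") is False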


def _load_disk_cache(expected_signature: str) -> Optional[Dict[str, Any]]:
cache_file = _cache_path()
if not cache_file.exists():
return None
try:
cached = json.loads(cache_file.read_text(encoding="utf-8"))
except Exception:
return None

if cached.get("signature") != expected_signature:
return None
if "ops" not in cached or "warnings" not in cached or "generated_at" not in cached:
return None
return cached


def _save_disk_cache(payload: Dict[str, Any]) -> None:
cache_file = _cache_path()
cache_file.parent.mkdir(parents=True, exist_ok=True)

fd, tmp_name = tempfile.mkstemp(prefix="pipeline_ops_", suffix=".json", dir=str(cache_file.parent))
try:
with os.fdopen(fd, "w", encoding="utf-8") as tmp:
json.dump(payload, tmp)
pathlib.Path(tmp_name).replace(cache_file)
finally:
try:
pathlib.Path(tmp_name).unlink(missing_ok=True)
except Exception:
pass
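Writing to a `mkstemp` file and then calling `Path.replace` makes the cache update atomic when the temp file and the manifest share a filesystem (which they do, since both live in the cache directory): concurrent readers see either the old manifest or the new one, never a truncated JSON file. The `finally` block only has work to do when the rename never happened.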


def _build_warning(module_name: str, stage: str, error: Exception) -> Dict[str, str]:
return {
"module": module_name,
"stage": stage,
"error_type": type(error).__name__,
"message": str(error),
}


def _get_parameter_types(cls) -> Dict[str, str]:
"""Extract parameter types from class constructor type annotations.
@@ -81,32 +214,50 @@ def _get_parameter_types(cls) -> Dict[str, str]:
return param_types


def _collect_pipeline_ops() -> List[Dict[str, Any]]:
def _collect_pipeline_ops(module_files: List[pathlib.Path], strict: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, str]]]:
"""Gather metadata for all available :class:`PipelineOp` subclasses."""
import logging
logger = logging.getLogger(__name__)

ops: List[Dict[str, Any]] = []
package = importlib.import_module("AFL.double_agent")
for modinfo in pkgutil.iter_modules(package.__path__):
module_name = f"{package.__name__}.{modinfo.name}"
warnings: List[Dict[str, str]] = []

for module_path in module_files:
module_name = f"AFL.double_agent.{module_path.stem}"
try:
module = importlib.import_module(module_name)
except Exception as e:
msg = f"Skipping module '{module_name}': failed to import ({type(e).__name__}: {e})"
print(msg)
logger.warning(msg)
warning = _build_warning(module_name, "import", e)
warnings.append(warning)
LOGGER.warning(
"Skipping module '%s': failed to import (%s: %s)",
module_name,
type(e).__name__,
e,
)
if strict:
raise RuntimeError(
f"PipelineOp discovery failed while importing '{module_name}': {type(e).__name__}: {e}"
) from e
continue

try:
members = inspect.getmembers(module, inspect.isclass)
except Exception as e:
msg = f"Skipping module '{module_name}': failed to inspect members ({type(e).__name__}: {e})"
print(msg)
logger.warning(msg)
warning = _build_warning(module_name, "inspect", e)
warnings.append(warning)
LOGGER.warning(
"Skipping module '%s': failed to inspect members (%s: %s)",
module_name,
type(e).__name__,
e,
)
if strict:
raise RuntimeError(
f"PipelineOp discovery failed while inspecting '{module_name}': {type(e).__name__}: {e}"
) from e
continue

for name, obj in members:
if obj.__module__ != module.__name__:
continue
try:
if not (issubclass(obj, PipelineOp) and obj is not PipelineOp):
continue
@@ -160,18 +311,91 @@ def _collect_pipeline_ops() -> List[Dict[str, Any]]:
}
)
except Exception as e:
msg = f"Skipping PipelineOp '{name}' from '{module.__name__}': failed to extract metadata ({type(e).__name__}: {e})"
print(msg)
logger.warning(msg)
warning = _build_warning(module_name, "metadata", e)
warnings.append(warning)
LOGGER.warning(
"Skipping PipelineOp '%s' from '%s': failed to extract metadata (%s: %s)",
name,
module.__name__,
type(e).__name__,
e,
)
if strict:
raise RuntimeError(
f"PipelineOp discovery failed for class '{name}' in '{module_name}': {type(e).__name__}: {e}"
) from e
continue

ops.sort(key=lambda o: o["name"])
return ops
return ops, warnings
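A sketch of the shapes the collector returns; the failing module shown in the comment is hypothetical:

ops, warnings = _collect_pipeline_ops(_candidate_module_files())
# ops: list of metadata dicts, sorted by op name
# warnings: one dict per skipped module or class, e.g.
#   {
#       "module": "AFL.double_agent.SomeModule",  # hypothetical
#       "stage": "import",                        # "import", "inspect", or "metadata"
#       "error_type": "ModuleNotFoundError",
#       "message": "No module named 'missing_dep'",
#   }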


def get_pipeline_ops(strict: bool = False) -> Dict[str, Any]:
"""Return metadata describing available pipeline operations with cache metadata."""
start = time.perf_counter()
module_files = _candidate_module_files()
signature = _module_signature(module_files)

if not strict:
with _DISCOVERY_LOCK:
global _PIPELINE_OPS_MEM_CACHE

if _PIPELINE_OPS_MEM_CACHE and _PIPELINE_OPS_MEM_CACHE.get("signature") == signature:
result = copy.deepcopy(_PIPELINE_OPS_MEM_CACHE)
result["cache"]["source"] = "memory"
result["cache"]["duration_ms"] = int((time.perf_counter() - start) * 1000)
result.pop("signature", None)
return result

disk_cached = _load_disk_cache(signature)
if disk_cached is not None:
result = {
"ops": disk_cached["ops"],
"warnings": disk_cached["warnings"],
"cache": {
"source": "disk",
"generated_at": disk_cached["generated_at"],
"signature": disk_cached["signature"],
"duration_ms": int((time.perf_counter() - start) * 1000),
},
"signature": disk_cached["signature"],
}
_PIPELINE_OPS_MEM_CACHE = copy.deepcopy(result)
result.pop("signature", None)
return result

ops, warnings = _collect_pipeline_ops(module_files, strict=strict)
generated_at = datetime.now(timezone.utc).isoformat()
duration_ms = int((time.perf_counter() - start) * 1000)

result = {
"ops": ops,
"warnings": warnings,
"cache": {
"source": "fresh",
"generated_at": generated_at,
"signature": signature,
"duration_ms": duration_ms,
},
"signature": signature,
}

if not strict:
cache_payload = {
"ops": ops,
"warnings": warnings,
"generated_at": generated_at,
"signature": signature,
}
with _DISCOVERY_LOCK:
_PIPELINE_OPS_MEM_CACHE = copy.deepcopy(result)
try:
_save_disk_cache(cache_payload)
except Exception as exc:
LOGGER.warning("Failed to write pipeline ops cache: %s: %s", type(exc).__name__, exc)

def get_pipeline_ops() -> List[Dict[str, Any]]:
"""Return metadata describing available pipeline operations."""
return _collect_pipeline_ops()
result.pop("signature", None)
return result
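A hypothetical session illustrating the cache provenance a caller can expect:

first = get_pipeline_ops()
first["cache"]["source"]   # "fresh" on a cold start, "disk" if a prior process wrote the manifest
second = get_pipeline_ops()
second["cache"]["source"]  # "memory": served from the in-process cache under the same signature

# strict=True bypasses both caches and escalates discovery warnings to RuntimeError.
strict_manifest = get_pipeline_ops(strict=True)
assert strict_manifest["cache"]["source"] == "fresh"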


def build_pipeline_from_ops(ops: List[Dict[str, Any]], name: str = "Pipeline") -> Dict[str, Any]:
@@ -532,7 +756,8 @@ def test_fetch_entry(self, entry_id: str = None, **kwargs):
@Driver.unqueued()
def pipeline_ops(self, **kwargs):
"""Return metadata for available PipelineOps."""
return get_pipeline_ops()
strict = _parse_strict_flag(kwargs.get("strict"))
return get_pipeline_ops(strict=strict)

@Driver.unqueued()
def current_pipeline(self, **kwargs):
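Assuming AFL-automation's convention of exposing unqueued driver methods as HTTP GET endpoints whose query parameters arrive as kwargs (host and port below are illustrative), a client could exercise the new strict flag on the `pipeline_ops` endpoint like so:

import requests

resp = requests.get("http://localhost:5000/pipeline_ops", params={"strict": "true"})
manifest = resp.json()
print(len(manifest["ops"]), "ops,", len(manifest["warnings"]), "warnings")
print(manifest["cache"])  # {"source": ..., "generated_at": ..., "signature": ..., "duration_ms": ...}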