From 4596c27490a06901a14ef6e2586b80aba9034d7c Mon Sep 17 00:00:00 2001
From: Christina Xu
Date: Thu, 18 Sep 2025 11:09:59 -0400
Subject: [PATCH] Add LM-Inline provider and unit tests

---
 providers.d/inline/eval/trustyai_lmeval.yaml |   5 +
 providers.d/remote/eval/trustyai_lmeval.yaml |   2 +-
 run-inline.yaml                              |  30 +
 src/llama_stack_provider_lmeval/__init__.py  |  54 --
 src/llama_stack_provider_lmeval/config.py    |   3 +
 .../inline/__init__.py                       |  52 ++
 .../inline/lmeval.py                         | 620 ++++++++++++++++++
 .../inline/provider.py                       |  13 +
 .../remote/__init__.py                       |  54 ++
 .../{ => remote}/lmeval.py                   |   6 +-
 .../{ => remote}/provider.py                 |   2 +-
 tests/test_base_url_handling.py              |   4 +-
 tests/test_env_vars.py                       |  10 +-
 tests/test_lmeval.py                         |  82 +--
 tests/test_lmeval_inline.py                  | 398 +++++++++++
 tests/test_namespace.py                      | 102 +--
 16 files changed, 1279 insertions(+), 158 deletions(-)
 create mode 100644 providers.d/inline/eval/trustyai_lmeval.yaml
 create mode 100644 run-inline.yaml
 create mode 100644 src/llama_stack_provider_lmeval/inline/__init__.py
 create mode 100644 src/llama_stack_provider_lmeval/inline/lmeval.py
 create mode 100644 src/llama_stack_provider_lmeval/inline/provider.py
 create mode 100644 src/llama_stack_provider_lmeval/remote/__init__.py
 rename src/llama_stack_provider_lmeval/{ => remote}/lmeval.py (99%)
 rename src/llama_stack_provider_lmeval/{ => remote}/provider.py (88%)
 create mode 100644 tests/test_lmeval_inline.py

diff --git a/providers.d/inline/eval/trustyai_lmeval.yaml b/providers.d/inline/eval/trustyai_lmeval.yaml
new file mode 100644
index 0000000..a84db99
--- /dev/null
+++ b/providers.d/inline/eval/trustyai_lmeval.yaml
@@ -0,0 +1,5 @@
+module: llama_stack_provider_lmeval.inline
+config_class: llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig
+pip_packages: ["lm-eval"]
+api_dependencies: ["inference", "files"]
+optional_api_dependencies: []
diff --git a/providers.d/remote/eval/trustyai_lmeval.yaml b/providers.d/remote/eval/trustyai_lmeval.yaml
index e1a0beb..12f5042 100644
--- a/providers.d/remote/eval/trustyai_lmeval.yaml
+++ b/providers.d/remote/eval/trustyai_lmeval.yaml
@@ -2,6 +2,6 @@ adapter:
   adapter_type: lmeval
   pip_packages: ["kubernetes"]
   config_class: llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig
-  module: llama_stack_provider_lmeval
+  module: llama_stack_provider_lmeval.remote
 api_dependencies: ["inference"]
 optional_api_dependencies: []
diff --git a/run-inline.yaml b/run-inline.yaml
new file mode 100644
index 0000000..2e9b169
--- /dev/null
+++ b/run-inline.yaml
@@ -0,0 +1,30 @@
+version: "2"
+image_name: trustyai-lmeval
+apis:
+  - inference
+  - eval
+  - files
+providers:
+  inference:
+  - provider_id: vllm
+    provider_type: remote::vllm
+    config:
+      url: ${env.VLLM_URL:=http://localhost:8080/v1}
+      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
+      api_token: ${env.VLLM_API_TOKEN:=fake}
+      tls_verify: ${env.VLLM_TLS_VERIFY:=false}
+  eval:
+  - provider_id: trustyai_lmeval
+    provider_type: inline::trustyai_lmeval
+    config:
+      base_url: ${env.BASE_URL:=http://localhost:8321/v1}
+      use_k8s: ${env.USE_K8S:=false}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/trustyai-lmeval/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.METADATA_STORE_DB_PATH:=~/.llama/distributions/trustyai-lmeval}/registry.db
+external_providers_dir: ./providers.d
diff --git a/src/llama_stack_provider_lmeval/__init__.py b/src/llama_stack_provider_lmeval/__init__.py
index 898cdca..e69de29 100644
---
a/src/llama_stack_provider_lmeval/__init__.py +++ b/src/llama_stack_provider_lmeval/__init__.py @@ -1,54 +0,0 @@ -import logging - -from llama_stack.apis.datatypes import Api -from llama_stack.providers.datatypes import ProviderSpec - -from .config import LMEvalEvalProviderConfig -from .lmeval import LMEval -from .provider import get_provider_spec - -# Set up logging -logger = logging.getLogger(__name__) - - -async def get_adapter_impl( - config: LMEvalEvalProviderConfig, - deps: dict[Api, ProviderSpec] | None = None, -) -> LMEval: - """Get an LMEval implementation from the configuration. - - Args: - config: LMEval configuration - deps: Optional dependencies for testing/injection - - Returns: - Configured LMEval implementation - - Raises: - Exception: If configuration is invalid - """ - try: - if deps is None: - deps = {} - - # Extract base_url from config if available - base_url = None - if hasattr(config, "model_args") and config.model_args: - for arg in config.model_args: - if arg.get("name") == "base_url": - base_url = arg.get("value") - logger.debug(f"Using base_url from config: {base_url}") - break - - return LMEval(config=config) - except Exception as e: - raise Exception(f"Failed to create LMEval implementation: {str(e)}") from e - - -__all__ = [ - # Factory methods - "get_adapter_impl", - # Configurations - "LMEval", - "get_provider_spec", -] diff --git a/src/llama_stack_provider_lmeval/config.py b/src/llama_stack_provider_lmeval/config.py index 4116adf..6920aae 100644 --- a/src/llama_stack_provider_lmeval/config.py +++ b/src/llama_stack_provider_lmeval/config.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from pathlib import Path from typing import Any from llama_stack.apis.eval import BenchmarkConfig, EvalCandidate @@ -125,6 +126,8 @@ class LMEvalEvalProviderConfig: metadata: dict[str, Any] | None = None # TLS configuration - structured approach tls: TLSConfig | None = None + base_dir: Path = Path(__file__).parent + results_dir: Path = base_dir / "results" def __post_init__(self): """Validate the configuration""" diff --git a/src/llama_stack_provider_lmeval/inline/__init__.py b/src/llama_stack_provider_lmeval/inline/__init__.py new file mode 100644 index 0000000..c1513cd --- /dev/null +++ b/src/llama_stack_provider_lmeval/inline/__init__.py @@ -0,0 +1,52 @@ +"""LMEval Inline Eval Llama Stack provider.""" + +import logging + +from llama_stack.apis.datatypes import Api +from llama_stack.providers.datatypes import ProviderSpec + +from llama_stack_provider_lmeval.config import LMEvalEvalProviderConfig + +from .lmeval import LMEvalInline + +logger = logging.getLogger(__name__) + + +async def get_provider_impl( + config: LMEvalEvalProviderConfig, + deps: dict[Api, ProviderSpec] | None = None, +) -> LMEvalInline: + """Get an inline Eval implementation from the configuration. 
+ + Args: + config: LMEvalEvalProviderConfig + deps: Optional[dict[Api, Any]] = None - can be ProviderSpec or API instances + + Returns: + Configured LMEval Inline implementation + + Raises: + Exception: If configuration is invalid + """ + try: + if deps is None: + deps = {} + + # Extract base_url from config if available + base_url = None + if hasattr(config, "model_args") and config.model_args: + for arg in config.model_args: + if arg.get("name") == "base_url": + base_url = arg.get("value") + logger.debug("Using base_url from config: %s", base_url) + break + + return LMEvalInline(config=config, deps=deps) + except Exception as e: + raise RuntimeError(f"Failed to create LMEval implementation: {str(e)}") from e + + +__all__ = [ + "get_provider_impl", + "LMEvalInline", +] diff --git a/src/llama_stack_provider_lmeval/inline/lmeval.py b/src/llama_stack_provider_lmeval/inline/lmeval.py new file mode 100644 index 0000000..def4129 --- /dev/null +++ b/src/llama_stack_provider_lmeval/inline/lmeval.py @@ -0,0 +1,620 @@ +"""LMEval Inline Eval Provider implementation for Llama Stack.""" + +import asyncio +import json +import logging +import os +import signal +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any + +from llama_stack.apis.benchmarks import Benchmark, ListBenchmarksResponse +from llama_stack.apis.common.job_types import Job, JobStatus +from llama_stack.apis.datatypes import Api +from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse +from llama_stack.apis.files import OpenAIFileObject, OpenAIFilePurpose, UploadFile +from llama_stack.apis.scoring import ScoringResult +from llama_stack.providers.datatypes import BenchmarksProtocolPrivate + +from ..config import LMEvalEvalProviderConfig +from ..errors import LMEvalConfigError + +logger = logging.getLogger(__name__) + + +class LMEvalInline(Eval, BenchmarksProtocolPrivate): + """LMEval inline provider implementation.""" + + def __init__( + self, config: LMEvalEvalProviderConfig, deps: dict[Api, Any] | None = None + ): + self.config: LMEvalEvalProviderConfig = config + self.benchmarks: dict[str, Benchmark] = {} + self._jobs: list[Job] = [] + self._job_metadata: dict[str, dict[str, str]] = {} + self.files_api = deps.get(Api.files) if deps else None + + async def initialize(self): + "Initialize the LMEval Inline provider" + if not self.files_api: + raise LMEvalConfigError("Files API is not initialized") + + async def list_benchmarks(self) -> ListBenchmarksResponse: + """List all registered benchmarks.""" + return ListBenchmarksResponse(data=list(self.benchmarks.values())) + + async def get_benchmark(self, benchmark_id: str) -> Benchmark | None: + """Get a specific benchmark by ID.""" + benchmark = self.benchmarks.get(benchmark_id) + return benchmark + + async def register_benchmark(self, benchmark: Benchmark) -> None: + """Register a benchmark for evaluation.""" + self.benchmarks[benchmark.identifier] = benchmark + + def _get_job_id(self) -> str: + """Generate a unique job ID.""" + return str(uuid.uuid4()) + + async def run_eval( + self, benchmark_id: str, benchmark_config: BenchmarkConfig, limit="2" + ) -> Job: + if not isinstance(benchmark_config, BenchmarkConfig): + raise LMEvalConfigError("LMEval requires BenchmarkConfig") + + stored_benchmark = await self.get_benchmark(benchmark_id) + + logger.info("Running evaluation for benchmark: %s", stored_benchmark) + + if ( + not hasattr(benchmark_config, "num_examples") + or benchmark_config.num_examples is None + ): + config_limit = None + 
else: + config_limit = str(benchmark_config.num_examples) + + job_output_results_dir: Path = self.config.results_dir + job_output_results_dir.mkdir(parents=True, exist_ok=True) + + # Generate unique job ID - use the same ID for both file naming and job tracking + job_id = self._get_job_id() + job_uuid = job_id.replace("-", "") + + try: + cmd = self.build_command( + benchmark_id=benchmark_id, + task_config=benchmark_config, + limit=config_limit or limit, + stored_benchmark=stored_benchmark, + job_output_results_dir=job_output_results_dir, + job_uuid=job_uuid, + ) + + logger.debug("Generated command for benchmark: %s", benchmark_id) + job = Job( + job_id=job_id, + status=JobStatus.scheduled, + metadata={"created_at": datetime.now().isoformat(), "process_id": None}, + ) + self._jobs.append(job) + self._job_metadata[job_id] = {} + + env = os.environ.copy() + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + self._job_metadata[job_id]["process_id"] = str(process.pid) + + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=300) + + if process.returncode == 0: + # Log successful completion + logger.info("Evaluation completed successfully for job %s", job_id) + # Check if the result file exists and process it + result_files = list( + job_output_results_dir.glob(f"job_{job_uuid}_results_*.json") + ) + if result_files: + # Use the most recent file if multiple exist + actual_result_file = max( + result_files, key=lambda f: f.stat().st_mtime + ) + logger.info("Found results file: %s", actual_result_file) + + # Parse results from the local file that lm_eval wrote to + try: + # Read and parse the results file + with open(actual_result_file, encoding="utf-8") as f: + results_data = json.load(f) + + # Parse the results using the existing method + parsed_results = await self._parse_job_results_from_data( + results_data, job_id + ) + # Store the parsed results in job metadata + self._job_metadata[job_id]["results"] = parsed_results + + # Upload the original lm_eval results file to Files API + upload_job_result: OpenAIFileObject = await self._upload_file( + actual_result_file, OpenAIFilePurpose.ASSISTANTS + ) + + if upload_job_result: + self._job_metadata[job_id]["uploaded_file"] = ( + upload_job_result.id + ) + logger.info( + "Uploaded job result file %s to Files API with ID: %s", + actual_result_file, + upload_job_result.id, + ) + else: + logger.warning( + "Failed to upload job result file %s to Files API", + actual_result_file, + ) + + job.status = JobStatus.completed + except Exception as e: + logger.error( + "Failed to process results file for job %s: %s", job_id, e + ) + job.status = JobStatus.failed + self._job_metadata[job_id]["error"] = ( + f"Failed to process results: {str(e)}" + ) + else: + logger.warning( + "No results files found for job %s in directory %s", + job_id, + job_output_results_dir, + ) + job.status = JobStatus.failed + self._job_metadata[job_id]["error"] = "Results file not found" + else: + logger.error( + "LM-Eval process failed with return code %d", process.returncode + ) + logger.error("stdout: %s", stdout.decode("utf-8") if stdout else "") + logger.error("stderr: %s", stderr.decode("utf-8") if stderr else "") + job.status = JobStatus.failed + self._job_metadata[job_id]["error"] = f""" + Process failed with return code {process.returncode} + """ + except Exception as e: + job.status = JobStatus.failed + self._job_metadata[job_id]["error"] = str(e) + logger.error("Job %s 
failed with error: %s", job_id, e) + # Only terminate if process is still running + if "process" in locals() and process and process.returncode is None: + try: + process.terminate() + except Exception as term_e: + logger.warning("Failed to terminate process: %s", term_e) + finally: + # Clean up any remaining process + if "process" in locals() and process and process.returncode is None: + process.kill() + await process.wait() + # Clean up job file + self._cleanup_job_files(job_output_results_dir, job_uuid) + + return job + + def _cleanup_job_files(self, results_dir: Path, job_uuid: str) -> None: + """Clean up result files for a specific job. + + Args: + results_dir: The results directory + job_uuid: The job UUID to clean up files for + """ + try: + # Find and remove all job files + job_files = list(results_dir.glob(f"job_{job_uuid}_results*.json")) + for file_path in job_files: + try: + file_path.unlink() + logger.debug("Deleted job result file: %s", file_path) + except OSError as e: + logger.warning( + "Failed to delete job result file %s: %s", file_path, e + ) + except Exception as e: + logger.warning("Error during job file cleanup for %s: %s", job_uuid, e) + + async def _upload_file( + self, file: Path, purpose: OpenAIFilePurpose + ) -> OpenAIFileObject | None: + if self.files_api is None: + logger.warning("Files API not available, cannot upload file %s", file) + return None + + if file.exists(): + with open(file, "rb") as f: + upload_file = await self.files_api.openai_upload_file( + file=UploadFile(file=f, filename=file.name), purpose=purpose + ) + return upload_file + else: + logger.warning("File %s does not exist", file) + return None + + async def _parse_job_results_from_data( + self, results_data: dict, job_id: str + ) -> EvaluateResponse: + if not results_data: + logger.warning("No results data for job %s", job_id) + return EvaluateResponse(generations=[], scores={}) + try: + # Extract generations and scores from lm_eval results + generations: list[dict[str, Any]] = [] + scores: dict[str, ScoringResult] = {} + + if "results" in results_data: + results = results_data["results"] + + # Extract scores for each task + for task_name, task_results in results.items(): + if isinstance(task_results, dict): + # Extract metric scores + for metric_name, metric_value in task_results.items(): + if isinstance(metric_value, int | float): + score_key = f"{task_name}:{metric_name}" + scores[score_key] = ScoringResult( + aggregated_results={ + metric_name: float(metric_value) + }, + score_rows=[{"score": float(metric_value)}], + ) + + # Extract generations if available (from samples) + if "samples" in results_data: + samples = results_data["samples"] + for task_name, task_samples in samples.items(): + if isinstance(task_samples, list): + for sample in task_samples: + if isinstance(sample, dict): + generation = { + "task": task_name, + "input": sample.get("doc", {}), + "output": sample.get("target", ""), + "generated": sample.get("resps", []), + } + generations.append(generation) + + logger.info("Successfully parsed results from file for job %s", job_id) + + return EvaluateResponse( + generations=generations, + scores=scores, + metadata={"job_id": job_id}, + ) + + except Exception as e: + logger.error( + "Failed to parse job results from file for job %s: %s", job_id, e + ) + return EvaluateResponse(generations=[], scores={}) + + def _create_model_args(self, base_url: str, benchmark_config: BenchmarkConfig): + model_args = {"model": None, "base_url": base_url} + + model_name = None + if 
hasattr(benchmark_config, "model") and benchmark_config.model: + model_name = benchmark_config.model + elif ( + hasattr(benchmark_config, "eval_candidate") + and benchmark_config.eval_candidate + ): + if ( + hasattr(benchmark_config.eval_candidate, "model") + and benchmark_config.eval_candidate.model + ): + model_name = benchmark_config.eval_candidate.model + + # Set model name and default parameters if we have a model + if model_name: + model_args["model"] = model_name + model_args["num_concurrent"] = "1" + model_args["max_retries"] = "3" + + # Apply any custom model args + if hasattr(benchmark_config, "model_args") and benchmark_config.model_args: + for arg in benchmark_config.model_args: + model_args[arg.name] = arg.value + + return model_args + + def _collect_lmeval_args( + self, task_config: BenchmarkConfig, stored_benchmark: Benchmark | None + ): + lmeval_args = {} + if hasattr(task_config, "lmeval_args") and task_config.lmeval_args: + lmeval_args = task_config.lmeval_args + + if hasattr(task_config, "metadata") and task_config.metadata: + metadata_lmeval_args = task_config.metadata.get("lmeval_args") + if metadata_lmeval_args: + for key, value in metadata_lmeval_args.items(): + lmeval_args[key] = value + + # Check stored benchmark for additional lmeval args + if ( + stored_benchmark + and hasattr(stored_benchmark, "metadata") + and stored_benchmark.metadata + ): + benchmark_lmeval_args = stored_benchmark.metadata.get("lmeval_args") + if benchmark_lmeval_args: + for key, value in benchmark_lmeval_args.items(): + lmeval_args[key] = value + + return lmeval_args + + def build_command( + self, + task_config: BenchmarkConfig, + benchmark_id: str, + limit: str, + stored_benchmark: Benchmark | None, + job_output_results_dir: Path, + job_uuid: str, + ) -> list[str]: + """Build lm_eval command with default args and user overrides.""" + + logger.info( + "BUILD_COMMAND: Starting to build command for benchmark: %s", benchmark_id + ) + logger.info("BUILD_COMMAND: Task config type: %s", type(task_config)) + logger.info( + "BUILD_COMMAND: Task config has metadata: %s", + hasattr(task_config, "metadata"), + ) + if hasattr(task_config, "metadata"): + logger.info( + "BUILD_COMMAND: Task config metadata content: %s", task_config.metadata + ) + + eval_candidate = task_config.eval_candidate + if not eval_candidate.type == "model": + raise LMEvalConfigError("LMEval only supports model candidates for now") + + # Create model args - use VLLM_URL environment variable for inference provider + inference_url = os.environ.get("VLLM_URL", "http://localhost:8080/v1") + openai_url = inference_url.replace("/v1", "/v1/completions") + model_args = self._create_model_args(openai_url, task_config) + + if ( + stored_benchmark is not None + and hasattr(stored_benchmark, "metadata") + and stored_benchmark.metadata + and "tokenizer" in stored_benchmark.metadata + ): + tokenizer_value = stored_benchmark.metadata.get("tokenizer") + if isinstance(tokenizer_value, str) and tokenizer_value: + logger.info( + "BUILD_COMMAND: Using custom tokenizer from metadata: %s", + tokenizer_value, + ) + + tokenized_requests = stored_benchmark.metadata.get("tokenized_requests") + if isinstance(tokenized_requests, bool) and tokenized_requests: + logger.info( + "BUILD_COMMAND: Using custom tokenized_requests from metadata: %s", + tokenized_requests, + ) + + model_args["tokenizer"] = tokenizer_value + model_args["tokenized_requests"] = tokenized_requests + + lmeval_args = self._collect_lmeval_args(task_config, stored_benchmark) + + # Start 
building the command + cmd = ["lm_eval"] + + # Add model type + model_type = model_args.get("model_type", "local-completions") + cmd.extend(["--model", model_type]) + + # Build model_args string + if model_args: + model_args_list = [] + for key, value in model_args.items(): + if key != "model_type" and value is not None: + model_args_list.append(f"{key}={value}") + + if model_args_list: + cmd.extend(["--model_args", ",".join(model_args_list)]) + + # Extract task name from benchmark_id (remove provider prefix) + # benchmark_id format: "inline::trustyai_lmeval::task_name" + task_name = ( + benchmark_id.split("::")[-1] if "::" in benchmark_id else benchmark_id + ) + cmd.extend(["--tasks", task_name]) + + cmd.extend(["--limit", limit]) + + cmd.extend( + ["--output_path", f"{job_output_results_dir}/job_{job_uuid}_results.json"] + ) + + # Add lmeval_args + if lmeval_args: + for key, value in lmeval_args.items(): + if value is not None: + cmd.extend([key, value]) + + logger.info( + "BUILD_COMMAND: Generated command for benchmark %s: %s", + benchmark_id, + " ".join(cmd), + ) + return cmd + + async def evaluate_rows( + self, + benchmark_id: str, + input_rows: list[dict[str, Any]], + scoring_functions: list[str], + benchmark_config: BenchmarkConfig, + ) -> EvaluateResponse: + """Evaluate a list of rows on a benchmark. + + Args: + benchmark_id: The ID of the benchmark to run the evaluation on. + input_rows: The rows to evaluate. + scoring_functions: The scoring functions to use for the evaluation. + benchmark_config: The configuration for the benchmark. + + Returns: + EvaluateResponse: Object containing generations and scores + """ + + raise NotImplementedError( + "Evaluate rows is not implemented, use run_eval instead" + ) + + async def job_cancel(self, benchmark_id: str, job_id: str) -> None: + """Cancel a running evaluation job. + + Args: + benchmark_id: The ID of the benchmark to run the evaluation on. + job_id: The ID of the job to cancel. 
+ """ + job = next((j for j in self._jobs if j.job_id == job_id), None) + if not job: + logger.warning("Job %s not found", job_id) + return + + if job.status in [JobStatus.completed, JobStatus.failed, JobStatus.cancelled]: + logger.warning("Job %s is not running", job_id) + return + + if job.status in [JobStatus.in_progress, JobStatus.scheduled]: + process_id_str = self._job_metadata.get(job_id, {}).get("process_id") + if process_id_str: + process_id = int(process_id_str) + logger.info("Attempting to cancel subprocess %s", process_id) + + try: + os.kill(process_id, signal.SIGTERM) + logger.info("Sent SIGTERM to process %s", process_id) + + for _ in range(10): + try: + # Check if process is still running + os.kill(process_id, 0) + await asyncio.sleep(0.5) + except ProcessLookupError: + # Process has terminated + logger.info("Process %s terminated gracefully", process_id) + break + else: + # Process didn't terminate gracefully, force kill + try: + os.kill(process_id, signal.SIGKILL) + logger.info("Sent SIGKILL to process %s", process_id) + except ProcessLookupError: + logger.info("Process %s already terminated", process_id) + + except ProcessLookupError: + logger.warning( + "Process %s not found (may have already terminated)", process_id + ) + except OSError as e: + logger.error("Error terminating process %s: %s", process_id, e) + + job.status = JobStatus.cancelled + + # Clean up result files for cancelled job + job_uuid = job_id.replace("-", "") + self._cleanup_job_files(self.config.results_dir, job_uuid) + + logger.info("Successfully cancelled job %s", job_id) + + async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: + """Get the results of a completed evaluation job. + + Args: + benchmark_id: The ID of the benchmark to run the evaluation on. + job_id: The ID of the job to get the results of. + """ + job = await self.job_status(benchmark_id, job_id) + + if job is None: + logger.warning("Job %s not found", job_id) + return EvaluateResponse(generations=[], scores={}) + + if job.status == JobStatus.completed: + # Get results from job metadata + job_metadata = self._job_metadata.get(job_id, {}) + results = job_metadata.get("results") + return results + + async def job_status(self, benchmark_id: str, job_id: str) -> Job | None: + """Get the status of a running evaluation job. + + Args: + benchmark_id: The ID of the benchmark to run the evaluation on. + job_id: The ID of the job to get the status of. 
+ """ + job = next((j for j in self._jobs if j.job_id == job_id), None) + + if job is None: + logger.warning("Job %s not found", job_id) + return None + + return job + + async def shutdown(self) -> None: + """Shutdown the LMEval Inline provider.""" + logger.info("Shutting down LMEval Inline provider") + + # Cancel all running jobs + running_jobs = [ + job + for job in self._jobs + if job.status in [JobStatus.in_progress, JobStatus.scheduled] + ] + + if running_jobs: + logger.info("Cancelling %d running jobs", len(running_jobs)) + for job in running_jobs: + try: + await self.job_cancel(benchmark_id="", job_id=job.job_id) + except Exception as e: + logger.warning("Failed to cancel job %s: %s", job.job_id, e) + await asyncio.sleep(0.1) # Brief pause between cancellations + + # Clean up any remaining result files + if self.config.results_dir.exists(): + try: + # Clean up result files for all known jobs + for job_id in list(self._job_metadata.keys()): + job_uuid = job_id.replace("-", "") + self._cleanup_job_files(self.config.results_dir, job_uuid) + except Exception as e: + logger.warning("Error during shutdown cleanup: %s", e) + + # Clear internal state + self._jobs.clear() + self._job_metadata.clear() + self.benchmarks.clear() + + # Close files API connection if it exists and has cleanup methods + if self.files_api and hasattr(self.files_api, "close"): + try: + await self.files_api.close() + logger.debug("Closed Files API connection") + except Exception as e: + logger.warning("Failed to close Files API connection: %s", e) + + logger.info("LMEval Inline provider shutdown complete") diff --git a/src/llama_stack_provider_lmeval/inline/provider.py b/src/llama_stack_provider_lmeval/inline/provider.py new file mode 100644 index 0000000..f289423 --- /dev/null +++ b/src/llama_stack_provider_lmeval/inline/provider.py @@ -0,0 +1,13 @@ +"""LMEval Inline Eval Llama Stack provider specification.""" + +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def get_provider_spec() -> ProviderSpec: + return InlineProviderSpec( + api=Api.eval, + provider_type="inline::trustyai_lmeval", + pip_packages=["lm-eval", "lm-eval[api]"], + config_class="llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig", + module="llama_stack_provider_lmeval.inline", + ) diff --git a/src/llama_stack_provider_lmeval/remote/__init__.py b/src/llama_stack_provider_lmeval/remote/__init__.py new file mode 100644 index 0000000..77e2f6d --- /dev/null +++ b/src/llama_stack_provider_lmeval/remote/__init__.py @@ -0,0 +1,54 @@ +import logging + +from llama_stack.apis.datatypes import Api +from llama_stack.providers.datatypes import ProviderSpec + +from ..config import LMEvalEvalProviderConfig +from .lmeval import LMEval +from .provider import get_provider_spec + +# Set up logging +logger = logging.getLogger(__name__) + + +async def get_adapter_impl( + config: LMEvalEvalProviderConfig, + deps: dict[Api, ProviderSpec] | None = None, +) -> LMEval: + """Get an LMEval implementation from the configuration. 
+ + Args: + config: LMEval configuration + deps: Optional dependencies for testing/injection + + Returns: + Configured LMEval implementation + + Raises: + Exception: If configuration is invalid + """ + try: + if deps is None: + deps = {} + + # Extract base_url from config if available + base_url = None + if hasattr(config, "model_args") and config.model_args: + for arg in config.model_args: + if arg.get("name") == "base_url": + base_url = arg.get("value") + logger.debug(f"Using base_url from config: {base_url}") + break + + return LMEval(config=config) + except Exception as e: + raise Exception(f"Failed to create LMEval implementation: {str(e)}") from e + + +__all__ = [ + # Factory methods + "get_adapter_impl", + # Configurations + "LMEval", + "get_provider_spec", +] diff --git a/src/llama_stack_provider_lmeval/lmeval.py b/src/llama_stack_provider_lmeval/remote/lmeval.py similarity index 99% rename from src/llama_stack_provider_lmeval/lmeval.py rename to src/llama_stack_provider_lmeval/remote/lmeval.py index 27ea738..c722f20 100644 --- a/src/llama_stack_provider_lmeval/lmeval.py +++ b/src/llama_stack_provider_lmeval/remote/lmeval.py @@ -21,8 +21,8 @@ from llama_stack.providers.datatypes import BenchmarksProtocolPrivate from pydantic import BaseModel -from .config import LMEvalEvalProviderConfig -from .errors import LMEvalConfigError, LMEvalTaskNameError +from ..config import LMEvalEvalProviderConfig +from ..errors import LMEvalConfigError, LMEvalTaskNameError logger = logging.getLogger(__name__) @@ -150,7 +150,7 @@ def _create_tls_volume_config( # Create TLSConfig object from environment variables for validation try: - from .config import TLSConfig + from ..config import TLSConfig tls_config = TLSConfig( enable=True, cert_file=cert_file, cert_secret=cert_secret diff --git a/src/llama_stack_provider_lmeval/provider.py b/src/llama_stack_provider_lmeval/remote/provider.py similarity index 88% rename from src/llama_stack_provider_lmeval/provider.py rename to src/llama_stack_provider_lmeval/remote/provider.py index 35eae24..8811d36 100644 --- a/src/llama_stack_provider_lmeval/provider.py +++ b/src/llama_stack_provider_lmeval/remote/provider.py @@ -13,6 +13,6 @@ def get_provider_spec() -> ProviderSpec: adapter_type="lmeval", pip_packages=["kubernetes"], config_class="llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig", - module="llama_stack_provider_lmeval", + module="llama_stack_provider_lmeval.remote", ), ) diff --git a/tests/test_base_url_handling.py b/tests/test_base_url_handling.py index 4ef2a4f..c544c34 100644 --- a/tests/test_base_url_handling.py +++ b/tests/test_base_url_handling.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from src.llama_stack_provider_lmeval.config import TLSConfig, LMEvalConfigError -from src.llama_stack_provider_lmeval.lmeval import LMEvalCRBuilder, ModelArg +from src.llama_stack_provider_lmeval.remote.lmeval import LMEvalCRBuilder, ModelArg BASE_URL = "http://example.com" @@ -259,4 +259,4 @@ def test_create_model_args_fallback_to_provider_tls_when_benchmark_tls_none(self if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() \ No newline at end of file diff --git a/tests/test_env_vars.py b/tests/test_env_vars.py index eed7c28..8e399ff 100644 --- a/tests/test_env_vars.py +++ b/tests/test_env_vars.py @@ -4,7 +4,7 @@ sys.path.insert(0, "src") -from llama_stack_provider_lmeval.lmeval import LMEvalCRBuilder +from llama_stack_provider_lmeval.remote.lmeval import LMEvalCRBuilder class 
TestEnvironmentVariables(unittest.TestCase): @@ -387,7 +387,7 @@ def test_logging_debug_output(self): }, ] - with patch("llama_stack_provider_lmeval.lmeval.logger") as mock_logger: + with patch("llama_stack_provider_lmeval.remote.lmeval.logger") as mock_logger: pod_config = self.cr_builder._create_pod_config(env_vars) # Verify that debug logging was called @@ -398,15 +398,15 @@ def test_logging_debug_output(self): call_args = mock_logger.debug.call_args format_string = call_args[0][0] # First argument is the format string format_args = call_args[0][1:] # Remaining arguments are the values - + self.assertEqual(format_string, "Setting pod environment variables: %s") self.assertEqual(len(format_args), 1) - + # The formatted message should contain the variable names formatted_message = format_string % format_args[0] self.assertIn("TEST_VAR", formatted_message) self.assertIn("SECRET_VAR", formatted_message) - + # Ensure no sensitive data is logged self.assertNotIn("test_value", formatted_message) self.assertNotIn("my-secret", formatted_message) diff --git a/tests/test_lmeval.py b/tests/test_lmeval.py index f566883..ffa3d57 100644 --- a/tests/test_lmeval.py +++ b/tests/test_lmeval.py @@ -5,7 +5,7 @@ import os from src.llama_stack_provider_lmeval.config import LMEvalEvalProviderConfig, TLSConfig -from src.llama_stack_provider_lmeval.lmeval import LMEvalCRBuilder, _get_tls_config_from_env +from src.llama_stack_provider_lmeval.remote.lmeval import LMEvalCRBuilder, _get_tls_config_from_env class TestTLSConfigFromEnv(unittest.TestCase): @@ -242,14 +242,14 @@ def test_case_insensitive_tls_disabled(self): result = _get_tls_config_from_env() self.assertIsNone(result) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_empty_cert_file(self, mock_logger): """Test TLS volume config when cert_file is empty string.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = "" os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "test-secret" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return None, None due to validation failure @@ -261,14 +261,14 @@ def test_tls_volume_config_with_empty_cert_file(self, mock_logger): warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list] self.assertTrue(any("TLS configuration validation failed" in call for call in warning_calls)) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_empty_cert_secret(self, mock_logger): """Test TLS volume config when cert_secret is empty string.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = "test-cert.pem" os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return None, None due to validation failure @@ -280,14 +280,14 @@ def test_tls_volume_config_with_empty_cert_secret(self, mock_logger): warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list] self.assertTrue(any("TLS configuration validation failed" in call for call in warning_calls)) - 
@patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_path_traversal_in_cert_file(self, mock_logger): """Test TLS volume config when cert_file contains path traversal characters.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = "../malicious.pem" os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "test-secret" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return None, None due to validation failure @@ -299,14 +299,14 @@ def test_tls_volume_config_with_path_traversal_in_cert_file(self, mock_logger): warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list] self.assertTrue(any("TLS configuration validation failed" in call for call in warning_calls)) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_unsafe_characters_in_cert_file(self, mock_logger): """Test TLS volume config when cert_file contains unsafe characters.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = "test@cert.pem" os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "test-secret" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return None, None due to validation failure @@ -318,14 +318,14 @@ def test_tls_volume_config_with_unsafe_characters_in_cert_file(self, mock_logger warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list] self.assertTrue(any("TLS configuration validation failed" in call for call in warning_calls)) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_valid_cert_file(self, mock_logger): """Test TLS volume config with valid certificate file name.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = "test-cert.pem" os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "test-secret" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return valid volume configuration @@ -353,14 +353,14 @@ def test_tls_volume_config_with_valid_cert_file(self, mock_logger): info_call_args = mock_logger.info.call_args[0] self.assertIn("Created TLS volume config", info_call_args[0]) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_valid_subdirectory_path(self, mock_logger): """Test TLS volume config with valid subdirectory path in cert_file.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = "certs/ca-bundle.pem" os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "test-secret" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return valid volume configuration @@ -377,14 
+377,14 @@ def test_tls_volume_config_with_valid_subdirectory_path(self, mock_logger): # Should log info about successful creation mock_logger.info.assert_called_once() - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_whitespace_only_cert_file(self, mock_logger): """Test TLS volume config when cert_file contains only whitespace.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = " " os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "test-secret" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return None, None due to validation failure @@ -396,14 +396,14 @@ def test_tls_volume_config_with_whitespace_only_cert_file(self, mock_logger): warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list] self.assertTrue(any("TLS configuration validation failed" in call for call in warning_calls)) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_valid_special_characters(self, mock_logger): """Test TLS volume config with valid certificate file name containing dots, hyphens, and underscores.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" os.environ["TRUSTYAI_LMEVAL_CERT_FILE"] = "test-cert-v1.2.pem" os.environ["TRUSTYAI_LMEVAL_CERT_SECRET"] = "test-secret" - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return valid volume configuration @@ -420,13 +420,13 @@ def test_tls_volume_config_with_valid_special_characters(self, mock_logger): # Should log info about successful creation mock_logger.info.assert_called_once() - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_tls_volume_config_with_no_certificates(self, mock_logger): """Test TLS volume config when TLS is enabled but no certificates specified (verify=True case).""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" # Neither TRUSTYAI_LMEVAL_CERT_FILE nor TRUSTYAI_LMEVAL_CERT_SECRET are set - from src.llama_stack_provider_lmeval.lmeval import _create_tls_volume_config + from src.llama_stack_provider_lmeval.remote.lmeval import _create_tls_volume_config volume_mounts, volumes = _create_tls_volume_config() # Should return None, None since no volumes are needed for verify=True @@ -438,7 +438,7 @@ def test_tls_volume_config_with_no_certificates(self, mock_logger): debug_calls = [call[0][0] for call in mock_logger.debug.call_args_list] self.assertTrue(any("no certificates specified, no volumes created" in call for call in debug_calls)) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_logging_when_only_cert_file_set(self, mock_logger): """Test that appropriate error logging occurs when only cert_file is set.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" @@ -457,7 +457,7 @@ def test_logging_when_only_cert_file_set(self, mock_logger): # Verify result is None (no fallback) self.assertIsNone(result) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + 
@patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_logging_when_only_cert_secret_set(self, mock_logger): """Test that appropriate error logging occurs when only cert_secret is set.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" @@ -476,7 +476,7 @@ def test_logging_when_only_cert_secret_set(self, mock_logger): # Verify result is None (no fallback) self.assertIsNone(result) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_logging_when_fallback_to_provider_config_successful(self, mock_logger): """Test that warning logging occurs when falling back to provider config.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" @@ -502,7 +502,7 @@ def test_logging_when_fallback_to_provider_config_successful(self, mock_logger): expected_path = "/etc/ssl/certs/provider-cert.pem" self.assertEqual(result, expected_path) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_logging_when_fallback_to_provider_config_tls_true(self, mock_logger): """Test that warning logging occurs when falling back to provider config TLS=True.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" @@ -527,7 +527,7 @@ def test_logging_when_fallback_to_provider_config_tls_true(self, mock_logger): # Verify result is True self.assertTrue(result) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_logging_when_fallback_not_possible(self, mock_logger): """Test that error logging occurs when fallback to provider config is not possible.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" @@ -548,7 +548,7 @@ def test_logging_when_fallback_not_possible(self, mock_logger): # Verify result is None self.assertIsNone(result) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_logging_when_using_environment_variables(self, mock_logger): """Test that debug logging occurs when using complete environment variables.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" @@ -566,7 +566,7 @@ def test_logging_when_using_environment_variables(self, mock_logger): expected_path = "/etc/ssl/certs/test-cert.pem" self.assertEqual(result, expected_path) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_logging_when_no_cert_variables_set(self, mock_logger): """Test that debug logging occurs when no certificate variables are set.""" os.environ["TRUSTYAI_LMEVAL_TLS"] = "true" @@ -615,14 +615,14 @@ def test_create_cr_with_model_in_eval_candidate(self): # Create a benchmark config without direct model attribute benchmark_config = MagicMock() - + # Create eval_candidate as a simple object with the required attributes class EvalCandidate: def __init__(self): self.type = "model" self.model = "eval-candidate-model" self.sampling_params = {} - + eval_candidate = EvalCandidate() benchmark_config.eval_candidate = eval_candidate benchmark_config.env_vars = [] @@ -641,12 +641,12 @@ def __init__(self): model_args = cr.get("spec", {}).get("modelArgs", []) model_arg = next((arg for arg in model_args if arg.get("name") == "model"), None) - + self.assertIsNotNone(model_arg, "Model argument should be present in modelArgs") - self.assertEqual(model_arg.get("value"), "eval-candidate-model", + self.assertEqual(model_arg.get("value"), "eval-candidate-model", "Model value should be extracted 
from eval_candidate.model") - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_without_tls(self, mock_logger): """Creating CR without no TLS configuration.""" config = LMEvalEvalProviderConfig( @@ -674,7 +674,7 @@ def test_create_cr_without_tls(self, mock_logger): "CR TLS configuration should be missing when not provided in the configuration", ) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_with_tls_false(self, mock_logger): """Creating CR with TLS verification bypass.""" config = LMEvalEvalProviderConfig( @@ -702,7 +702,7 @@ def test_create_cr_with_tls_false(self, mock_logger): "CR TLS configuration should be missing when not provided in the configuration", ) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_with_tls_certificate_path(self, mock_logger): """Creating CR with TLS certificate path.""" config = LMEvalEvalProviderConfig( @@ -730,7 +730,7 @@ def test_create_cr_with_tls_certificate_path(self, mock_logger): "TLS configuration should be missing when not provided in the configuration", ) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") @patch.dict(os.environ, {"TRUSTYAI_LMEVAL_TLS": "true"}) def test_create_cr_with_env_tls_true(self, mock_logger): """Creating CR with TLS verification enabled via environment variable.""" @@ -764,7 +764,7 @@ def test_create_cr_with_env_tls_true(self, mock_logger): "TLS configuration value should be 'True'", ) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") @patch.dict(os.environ, { "TRUSTYAI_LMEVAL_TLS": "true", "TRUSTYAI_LMEVAL_CERT_FILE": "custom-ca.pem", @@ -807,7 +807,7 @@ def test_create_cr_with_env_tls_certificate(self, mock_logger): self.assertIsNotNone(pod_config.get("volumes"), "Pod should have volumes for TLS certificate") self.assertIsNotNone(pod_config.get("container", {}).get("volumeMounts"), "Container should have volume mounts for TLS certificate") - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_with_provider_config_tls_true(self, mock_logger): """Creating CR with TLS verification enabled via provider config (backward compatibility).""" config = LMEvalEvalProviderConfig( @@ -841,7 +841,7 @@ def test_create_cr_with_provider_config_tls_true(self, mock_logger): "TLS configuration value should be 'True'", ) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_with_provider_config_tls_certificate(self, mock_logger): """Creating CR with TLS certificate path via provider config (backward compatibility).""" config = LMEvalEvalProviderConfig( @@ -1097,7 +1097,7 @@ def test_create_cr_with_provider_config_tls_unsafe_characters_still_blocked(self # Verify the error message - should mention potentially unsafe characters self.assertIn("contains potentially unsafe characters", str(excinfo.exception)) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_with_provider_config_tls_no_certificates(self, mock_logger): """Test TLS enabled but no certificates specified 
(verify=True case), should work correctly.""" # This should work without raising validation errors @@ -1143,7 +1143,7 @@ def test_create_cr_with_provider_config_tls_no_certificates(self, mock_logger): self.assertIsNone(pod_config.get("volumes"), "Pod should not have volumes for verify=True case") self.assertIsNone(pod_config.get("container", {}).get("volumeMounts"), "Container should not have volume mounts for verify=True case") - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_without_tokenizer(self, mock_logger): """Creating CR without tokenizer specified.""" config = LMEvalEvalProviderConfig( @@ -1171,7 +1171,7 @@ def test_create_cr_without_tokenizer(self, mock_logger): "Tokenizer should not be present when not specified in metadata", ) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_with_custom_tokenizer(self, mock_logger): """Creating CR with custom tokenizer specified in metadata.""" config = LMEvalEvalProviderConfig( @@ -1205,7 +1205,7 @@ def test_create_cr_with_custom_tokenizer(self, mock_logger): "Tokenizer value should match the value specified in the request's metadata", ) - @patch("src.llama_stack_provider_lmeval.lmeval.logger") + @patch("src.llama_stack_provider_lmeval.remote.lmeval.logger") def test_create_cr_with_tokenized_requests(self, mock_logger): """Creating CR with tokenized_requests specified in metadata.""" config = LMEvalEvalProviderConfig( diff --git a/tests/test_lmeval_inline.py b/tests/test_lmeval_inline.py new file mode 100644 index 0000000..2fcb13d --- /dev/null +++ b/tests/test_lmeval_inline.py @@ -0,0 +1,398 @@ +"""Comprehensive unit tests for LMEval inline provider.""" +import subprocess +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from llama_stack.apis.benchmarks import Benchmark +from llama_stack.apis.common.job_types import JobStatus +from llama_stack.apis.datatypes import Api +from llama_stack.apis.eval import BenchmarkConfig + +from src.llama_stack_provider_lmeval.config import LMEvalEvalProviderConfig +from src.llama_stack_provider_lmeval.errors import ( + LMEvalConfigError, + LMEvalTaskNameError, +) +from src.llama_stack_provider_lmeval.inline.lmeval import LMEvalInline +from src.llama_stack_provider_lmeval.inline.provider import get_provider_spec + + +def create_mock_files_api(): + """Create a mock Files API with required methods.""" + mock_files_api = MagicMock() + mock_files_api.openai_upload_file = AsyncMock() + return mock_files_api + + +class TestLMEvalInlineProvider(unittest.TestCase): + """Unit tests for the LMEvalInlineProvider.""" + + def test_get_provider_spec(self): + """Test that provider spec is correctly configured.""" + spec = get_provider_spec() + assert spec is not None + assert spec.api == Api.eval + assert spec.pip_packages == ["lm-eval", "lm-eval[api]"] + assert spec.config_class == "llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig" + assert spec.module == "llama_stack_provider_lmeval.inline" + + +class TestLMEvalInlineInitialization(unittest.IsolatedAsyncioTestCase): + """Test LMEvalInline class initialization and basic functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = LMEvalEvalProviderConfig(use_k8s=False) + self.provider = LMEvalInline(self.config, deps={Api.files: create_mock_files_api()}) + + async def test_initialization(self): + """Test 
that LMEvalInline initializes correctly.""" + assert self.provider.config == self.config + assert isinstance(self.provider.benchmarks, dict) + assert len(self.provider.benchmarks) == 0 + assert isinstance(self.provider._jobs, list) + assert len(self.provider._jobs) == 0 + assert isinstance(self.provider._job_metadata, dict) + assert len(self.provider._job_metadata) == 0 + + async def test_initialize_method(self): + """Test that initialize method completes without error.""" + await self.provider.initialize() + # Should complete without raising any exceptions + + +class TestBenchmarkManagement(unittest.IsolatedAsyncioTestCase): + """Test benchmark management functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = LMEvalEvalProviderConfig(use_k8s=False) + self.provider = LMEvalInline(self.config, deps={Api.files: create_mock_files_api()}) + + # Create test benchmark + self.test_benchmark = Benchmark( + identifier="lmeval::mmlu", + dataset_id="trustyai_lmeval::mmlu", + scoring_functions=[], + provider_id="inline::trustyai_lmeval", + ) + + async def test_register_benchmark(self): + """Test benchmark registration.""" + await self.provider.register_benchmark(self.test_benchmark) + + assert "lmeval::mmlu" in self.provider.benchmarks + assert self.provider.benchmarks["lmeval::mmlu"] == self.test_benchmark + + async def test_get_benchmark_existing(self): + """Test getting an existing benchmark.""" + await self.provider.register_benchmark(self.test_benchmark) + + result = await self.provider.get_benchmark("lmeval::mmlu") + assert result == self.test_benchmark + + async def test_get_benchmark_nonexistent(self): + """Test getting a non-existent benchmark.""" + result = await self.provider.get_benchmark("nonexistent::benchmark") + assert result is None + + async def test_list_benchmarks_empty(self): + """Test listing benchmarks when none are registered.""" + result = await self.provider.list_benchmarks() + assert len(result.data) == 0 + + async def test_list_benchmarks_with_data(self): + """Test listing benchmarks when some are registered.""" + await self.provider.register_benchmark(self.test_benchmark) + + result = await self.provider.list_benchmarks() + assert len(result.data) == 1 + assert result.data[0] == self.test_benchmark + + +class TestJobManagement(unittest.IsolatedAsyncioTestCase): + """Test job management functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = LMEvalEvalProviderConfig(use_k8s=False) + self.provider = LMEvalInline(self.config, deps={Api.files: create_mock_files_api()}) + + def test_get_job_id(self): + """Test job ID generation.""" + job_id1 = self.provider._get_job_id() + job_id2 = self.provider._get_job_id() + + assert isinstance(job_id1, str) + assert isinstance(job_id2, str) + assert job_id1 != job_id2 + +class TestModelArgsCreation(unittest.IsolatedAsyncioTestCase): + """Test model args creation functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = LMEvalEvalProviderConfig(use_k8s=False) + self.provider = LMEvalInline(self.config, deps={Api.files: create_mock_files_api()}) + + async def test_create_model_args_from_config_model(self): + """Test model args creation when model is in config.""" + benchmark_config = MagicMock() + benchmark_config.model = "test-model" + benchmark_config.eval_candidate = None + benchmark_config.model_args = [] + + model_args = self.provider._create_model_args("http://test", benchmark_config) + + expected = { + "model": "test-model", + "base_url": 
"http://test", + "num_concurrent": "1", + "max_retries": "3", + } + assert model_args == expected + + async def test_create_model_args_from_eval_candidate(self): + """Test model args creation when model is in eval_candidate.""" + benchmark_config = MagicMock() + benchmark_config.model = None + + eval_candidate = MagicMock() + eval_candidate.model = "candidate-model" + benchmark_config.eval_candidate = eval_candidate + benchmark_config.model_args = [] + + model_args = self.provider._create_model_args("http://test", benchmark_config) + + expected = { + "model": "candidate-model", + "base_url": "http://test", + "num_concurrent": "1", + "max_retries": "3", + } + assert model_args == expected + + async def test_create_model_args_with_additional_args(self): + """Test model args creation with additional model_args.""" + benchmark_config = MagicMock() + benchmark_config.model = "test-model" + benchmark_config.eval_candidate = None + + # Mock model_args as a list of objects with name and value attributes + model_arg1 = MagicMock() + model_arg1.name = "temperature" + model_arg1.value = "0.7" + model_arg2 = MagicMock() + model_arg2.name = "max_tokens" + model_arg2.value = "100" + + benchmark_config.model_args = [model_arg1, model_arg2] + + model_args = self.provider._create_model_args("http://test", benchmark_config) + + expected = { + "model": "test-model", + "base_url": "http://test", + "num_concurrent": "1", + "max_retries": "3", + "temperature": "0.7", + "max_tokens": "100", + } + assert model_args == expected + + async def test_create_model_args_override_defaults(self): + """Test that custom model_args can override defaults.""" + benchmark_config = MagicMock() + benchmark_config.model = "test-model" + benchmark_config.eval_candidate = None + + # Override default num_concurrent + model_arg = MagicMock() + model_arg.name = "num_concurrent" + model_arg.value = "5" + + benchmark_config.model_args = [model_arg] + + model_args = self.provider._create_model_args("http://test", benchmark_config) + + expected = { + "model": "test-model", + "base_url": "http://test", + "num_concurrent": "5", # Should be overridden + "max_retries": "3", + } + assert model_args == expected + + +class TestLMEvalArgsCollection(unittest.IsolatedAsyncioTestCase): + """Test LMEval args collection functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = LMEvalEvalProviderConfig(use_k8s=False) + self.provider = LMEvalInline(self.config, deps={Api.files: create_mock_files_api()}) + + async def test_collect_lmeval_args_from_task_config(self): + """Test collecting lmeval_args from task config.""" + task_config = MagicMock() + task_config.lmeval_args = {"arg1": "value1", "arg2": "value2"} + task_config.metadata = None + + stored_benchmark = None + + result = self.provider._collect_lmeval_args(task_config, stored_benchmark) + + expected = {"arg1": "value1", "arg2": "value2"} + assert result == expected + + async def test_collect_lmeval_args_from_stored_benchmark(self): + """Test collecting lmeval_args from stored benchmark metadata.""" + task_config = MagicMock() + task_config.lmeval_args = None + task_config.metadata = None + + stored_benchmark = MagicMock() + stored_benchmark.metadata = { + "lmeval_args": {"bench_arg1": "bench_value1", "bench_arg2": "bench_value2"} + } + + result = self.provider._collect_lmeval_args(task_config, stored_benchmark) + + expected = {"bench_arg1": "bench_value1", "bench_arg2": "bench_value2"} + assert result == expected + + async def test_collect_lmeval_args_combined(self): + 
"""Test collecting lmeval_args from multiple sources with precedence.""" + task_config = MagicMock() + task_config.lmeval_args = {"arg1": "task_value1"} + task_config.metadata = None + + stored_benchmark = MagicMock() + stored_benchmark.metadata = { + "lmeval_args": {"arg1": "bench_value1", "arg2": "bench_value2"} + } + + result = self.provider._collect_lmeval_args(task_config, stored_benchmark) + + # Stored benchmark args should override task config args + expected = {"arg1": "bench_value1", "arg2": "bench_value2"} + assert result == expected + + +class TestCommandBuilding(unittest.IsolatedAsyncioTestCase): + """Test command building functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = LMEvalEvalProviderConfig(use_k8s=False) + self.provider = LMEvalInline(self.config, deps={Api.files: create_mock_files_api()}) + + async def test_build_command_basic(self): + """Test building basic command.""" + task_config = MagicMock() + eval_candidate = MagicMock() + eval_candidate.type = "model" + eval_candidate.model = "test-model" + task_config.eval_candidate = eval_candidate + task_config.model = "test-model" + task_config.model_args = [] + task_config.metadata = {} + + stored_benchmark = None + + cmd = self.provider.build_command( + task_config=task_config, + benchmark_id="lmeval::mmlu", + limit="10", + stored_benchmark=stored_benchmark, + job_output_results_dir=Path("/tmp"), + job_uuid="test_job_uuid" + ) + cmd_str = " ".join(cmd) + assert "lm_eval" in cmd_str + assert "--model" in cmd_str + assert "local-completions" in cmd_str + assert "--model_args" in cmd_str + assert "--tasks" in cmd_str + assert "mmlu" in cmd_str + assert "--limit" in cmd_str + assert "10" in cmd_str + + async def test_build_command_with_tokenizer_metadata(self): + """Test building command with tokenizer in stored benchmark metadata.""" + task_config = MagicMock() + eval_candidate = MagicMock() + eval_candidate.type = "model" + eval_candidate.model = "test-model" + task_config.eval_candidate = eval_candidate + task_config.model = "test-model" + task_config.model_args = [] + task_config.metadata = {} + + stored_benchmark = MagicMock() + stored_benchmark.metadata = { + "tokenizer": "custom-tokenizer", + "tokenized_requests": True + } + + cmd = self.provider.build_command( + task_config=task_config, + benchmark_id="lmeval::mmlu", + limit="10", + stored_benchmark=stored_benchmark, + job_output_results_dir=Path("/tmp"), + job_uuid="test_job_uuid" + ) + + cmd_str = " ".join(cmd) + assert "tokenizer=custom-tokenizer" in cmd_str + assert "tokenized_requests=True" in cmd_str + + async def test_build_command_with_lmeval_args(self): + """Test building command with lmeval_args.""" + task_config = MagicMock() + eval_candidate = MagicMock() + eval_candidate.type = "model" + eval_candidate.model = "test-model" + task_config.eval_candidate = eval_candidate + task_config.model = "test-model" + task_config.model_args = [] + task_config.metadata = {} + task_config.lmeval_args = {"output_path": "/tmp/results"} + + stored_benchmark = None + + cmd = self.provider.build_command( + task_config=task_config, + benchmark_id="lmeval::mmlu", + limit="10", + stored_benchmark=stored_benchmark, + job_output_results_dir=Path("/tmp"), + job_uuid="test_job_uuid" + ) + cmd_str = " ".join(cmd) + assert "output_path" in cmd_str + assert "/tmp/results" in cmd_str + + async def test_build_command_invalid_eval_candidate(self): + """Test building command with invalid eval candidate type.""" + task_config = MagicMock() + eval_candidate 
+        eval_candidate.type = "dataset"  # Invalid type
+        task_config.eval_candidate = eval_candidate
+
+        stored_benchmark = None
+
+        with self.assertRaises(LMEvalConfigError):
+            self.provider.build_command(
+                task_config=task_config,
+                benchmark_id="lmeval::mmlu",
+                limit="10",
+                stored_benchmark=stored_benchmark,
+                job_output_results_dir=Path("/tmp"),
+                job_uuid="test_job_uuid"
+            )
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_namespace.py b/tests/test_namespace.py
index ebd2f2e..194a73d 100644
--- a/tests/test_namespace.py
+++ b/tests/test_namespace.py
@@ -7,7 +7,7 @@
 
 from llama_stack_provider_lmeval.config import LMEvalEvalProviderConfig
 from llama_stack_provider_lmeval.errors import LMEvalConfigError
-from llama_stack_provider_lmeval.lmeval import _resolve_namespace
+from llama_stack_provider_lmeval.remote.lmeval import _resolve_namespace
 
 
 class TestNamespaceResolution(unittest.TestCase):
@@ -17,7 +17,7 @@ def setUp(self):
         """Set up test fixtures."""
         env_vars_to_clear = [
             'TRUSTYAI_LM_EVAL_NAMESPACE',
-            'POD_NAMESPACE', 
+            'POD_NAMESPACE',
             'NAMESPACE'
         ]
         for var in env_vars_to_clear:
@@ -38,28 +38,28 @@ def tearDown(self):
     def test_namespace_from_provider_config(self):
         """Test namespace resolution from provider config."""
         config = LMEvalEvalProviderConfig(namespace="test-namespace")
-        
-        with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+
+        with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
             namespace = _resolve_namespace(config)
-        
+
         self.assertEqual(namespace, "test-namespace")
         mock_logger.debug.assert_called_with("Using namespace from provider config: %s", "test-namespace")
 
     def test_namespace_respects_any_config_value(self):
         """Test that any namespace value in provider config is respected."""
         config = LMEvalEvalProviderConfig(namespace="default")
-        
-        with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+
+        with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
            namespace = _resolve_namespace(config)
-        
+
         self.assertEqual(namespace, "default")
         mock_logger.debug.assert_called_with("Using namespace from provider config: %s", "default")
-        
+
         config_test = LMEvalEvalProviderConfig(namespace="test")
-        
-        with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+
+        with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
             namespace = _resolve_namespace(config_test)
-        
+
         self.assertEqual(namespace, "test")
         mock_logger.debug.assert_called_with("Using namespace from provider config: %s", "test")
 
@@ -67,120 +67,120 @@ def test_namespace_from_trustyai_env_var(self):
         """Test namespace resolution from TRUSTYAI_LM_EVAL_NAMESPACE environment variable."""
         config = LMEvalEvalProviderConfig()
         os.environ['TRUSTYAI_LM_EVAL_NAMESPACE'] = 'trustyai-namespace'
-        
-        with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+
+        with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
             namespace = _resolve_namespace(config)
-        
+
         self.assertEqual(namespace, "trustyai-namespace")
         mock_logger.debug.assert_called_with("Using namespace from environment variable: %s", "trustyai-namespace")
 
-    @patch('llama_stack_provider_lmeval.lmeval.Path')
+    @patch('llama_stack_provider_lmeval.remote.lmeval.Path')
     def test_namespace_from_service_account_file(self, mock_path):
         """Test namespace resolution from service account file."""
         config = LMEvalEvalProviderConfig()
-        
+
         mock_path_instance = mock_path.return_value
         mock_path_instance.exists.return_value = True
-        
+
         with patch('builtins.open', mock_open(read_data='service-account-namespace')):
-            with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+            with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
                 namespace = _resolve_namespace(config)
-                
+
                 self.assertEqual(namespace, "service-account-namespace")
                 mock_logger.debug.assert_called_with("Using namespace from service account: %s", "service-account-namespace")
 
-    @patch('llama_stack_provider_lmeval.lmeval.Path')
+    @patch('llama_stack_provider_lmeval.remote.lmeval.Path')
    def test_namespace_from_empty_service_account_file(self, mock_path):
         """Test namespace resolution when service account file is empty or whitespace."""
         config = LMEvalEvalProviderConfig()
         os.environ['TRUSTYAI_LM_EVAL_NAMESPACE'] = 'trustyai-namespace'
-        
+
         mock_path_instance = mock_path.return_value
         mock_path_instance.exists.return_value = True
-        
+
         with patch('builtins.open', mock_open(read_data='')):
-            with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+            with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
                 namespace = _resolve_namespace(config)
-                
+
                 self.assertEqual(namespace, "trustyai-namespace")
                 mock_logger.debug.assert_called_with("Using namespace from environment variable: %s", "trustyai-namespace")
 
-    @patch('llama_stack_provider_lmeval.lmeval.Path')
+    @patch('llama_stack_provider_lmeval.remote.lmeval.Path')
     def test_namespace_from_whitespace_service_account_file(self, mock_path):
         """Test namespace resolution when service account file contains only whitespace."""
-        
+
         config = LMEvalEvalProviderConfig()
         os.environ['POD_NAMESPACE'] = 'pod-namespace'
-        
+
         mock_path_instance = mock_path.return_value
         mock_path_instance.exists.return_value = True
-        
+
         with patch('builtins.open', mock_open(read_data=' \n\t ')):
-            with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+            with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
                 namespace = _resolve_namespace(config)
-                
+
                 self.assertEqual(namespace, "pod-namespace")
                 mock_logger.debug.assert_called_with("Using namespace from POD_NAMESPACE environment variable: %s", "pod-namespace")
 
-    @patch('llama_stack_provider_lmeval.lmeval.Path')
+    @patch('llama_stack_provider_lmeval.remote.lmeval.Path')
     def test_namespace_from_pod_namespace_env_var(self, mock_path):
         """Test namespace resolution from POD_NAMESPACE environment variable."""
         config = LMEvalEvalProviderConfig()
         os.environ['POD_NAMESPACE'] = 'pod-namespace'
-        
+
         mock_path_instance = mock_path.return_value
         mock_path_instance.exists.return_value = False
-        
-        with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+
+        with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
             namespace = _resolve_namespace(config)
-        
+
         self.assertEqual(namespace, "pod-namespace")
         mock_logger.debug.assert_called_with("Using namespace from POD_NAMESPACE environment variable: %s", "pod-namespace")
 
-    @patch('llama_stack_provider_lmeval.lmeval.Path')
+    @patch('llama_stack_provider_lmeval.remote.lmeval.Path')
     def test_namespace_from_namespace_env_var(self, mock_path):
         """Test namespace resolution from NAMESPACE environment variable."""
         config = LMEvalEvalProviderConfig()
         os.environ['NAMESPACE'] = 'generic-namespace'
-        
+
         mock_path_instance = mock_path.return_value
         mock_path_instance.exists.return_value = False
-        
-        with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+
+        with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
             namespace = _resolve_namespace(config)
-        
+
         self.assertEqual(namespace, "generic-namespace")
         mock_logger.debug.assert_called_with("Using namespace from NAMESPACE environment variable: %s", "generic-namespace")
 
-    @patch('llama_stack_provider_lmeval.lmeval.Path')
+    @patch('llama_stack_provider_lmeval.remote.lmeval.Path')
     def test_namespace_resolution_priority(self, mock_path):
         """Test that namespace resolution follows right order."""
         config = LMEvalEvalProviderConfig(namespace="config-namespace")
         os.environ['TRUSTYAI_LM_EVAL_NAMESPACE'] = 'trustyai-namespace'
         os.environ['POD_NAMESPACE'] = 'pod-namespace'
-        
+
         mock_path_instance = mock_path.return_value
         mock_path_instance.exists.return_value = True
-        
+
         with patch('builtins.open', mock_open(read_data='service-account-namespace')):
-            with patch('llama_stack_provider_lmeval.lmeval.logger') as mock_logger:
+            with patch('llama_stack_provider_lmeval.remote.lmeval.logger') as mock_logger:
                 namespace = _resolve_namespace(config)
-                
+
                 self.assertEqual(namespace, "config-namespace")
                 mock_logger.debug.assert_called_with("Using namespace from provider config: %s", "config-namespace")
 
-    @patch('llama_stack_provider_lmeval.lmeval.Path')
+    @patch('llama_stack_provider_lmeval.remote.lmeval.Path')
     def test_namespace_resolution_failure(self, mock_path):
         """Test that function raises exception when no namespace is found."""
         config = LMEvalEvalProviderConfig()
-        
+
         mock_path_instance = mock_path.return_value
         mock_path_instance.exists.return_value = False
-        
+
         with self.assertRaises(LMEvalConfigError) as context:
             _resolve_namespace(config)
-        
+
         error_msg = str(context.exception)
         self.assertIn("Unable to determine namespace", error_msg)
         self.assertIn("Set 'namespace' in your run.yaml provider config", error_msg)
@@ -188,4 +188,4 @@
 
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
\ No newline at end of file