
Commit 60644a9

Fix failing tests
1 parent 0d61839 commit 60644a9

File tree: 3 files changed, +90 −184 lines


src/llama_stack_provider_lmeval/inline/lmeval.py

Lines changed: 46 additions & 24 deletions
@@ -10,11 +10,11 @@
 from pathlib import Path
 from typing import Any
 
-from llama_stack.apis.datatypes import Api
-from llama_stack.apis.files import OpenAIFileObject, OpenAIFilePurpose, UploadFile
 from llama_stack.apis.benchmarks import Benchmark, ListBenchmarksResponse
 from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.datatypes import Api
 from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse
+from llama_stack.apis.files import OpenAIFileObject, OpenAIFilePurpose, UploadFile
 from llama_stack.apis.scoring import ScoringResult
 from llama_stack.providers.datatypes import BenchmarksProtocolPrivate

@@ -34,10 +34,12 @@ def __init__(
         self.benchmarks: dict[str, Benchmark] = {}
         self._jobs: list[Job] = []
         self._job_metadata: dict[str, dict[str, str]] = {}
-        self.files_api = deps.get(Api.files)
+        self.files_api = deps.get(Api.files) if deps else None
 
     async def initialize(self):
         "Initialize the LMEval Inline provider"
+        if not self.files_api:
+            raise LMEvalConfigError("Files API is not initialized")
 
     async def list_benchmarks(self) -> ListBenchmarksResponse:
         """List all registered benchmarks."""
@@ -59,7 +61,6 @@ def _get_job_id(self) -> str:
     async def run_eval(
         self, benchmark_id: str, benchmark_config: BenchmarkConfig, limit="2"
     ) -> Job:
-
         if not isinstance(benchmark_config, BenchmarkConfig):
             raise LMEvalConfigError("LMEval requires BenchmarkConfig")

@@ -109,7 +110,7 @@ async def run_eval(
             env=env,
         )
 
-        self._job_metadata[job_id]["process_id"] = process.pid
+        self._job_metadata[job_id]["process_id"] = str(process.pid)
 
         stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=300)
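
The str() cast lines the stored value up with the declared shape of the metadata store, _job_metadata: dict[str, dict[str, str]] (see the hunk at -34 above): process.pid is an int, so storing it raw contradicted the annotation. The cancel path (hunk at -475 below) parses it back with int(). A small sketch of the round trip; the job ID and PID values are illustrative:

    job_metadata: dict[str, dict[str, str]] = {}

    job_id, pid = "job-1", 12345          # pid stands in for process.pid (an int)
    job_metadata[job_id] = {"process_id": str(pid)}   # store as str, per annotation

    # Later, when cancelling, convert back before signalling the process:
    process_id = int(job_metadata[job_id]["process_id"])
    assert process_id == pid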

@@ -150,9 +151,9 @@ async def run_eval(
                     )
 
                     if upload_job_result:
-                        self._job_metadata[job_id][
-                            "uploaded_file"
-                        ] = upload_job_result.id
+                        self._job_metadata[job_id]["uploaded_file"] = (
+                            upload_job_result.id
+                        )
                         logger.info(
                             "Uploaded job result file %s to Files API with ID: %s",
                             actual_result_file,
@@ -170,9 +171,9 @@ async def run_eval(
                         "Failed to process results file for job %s: %s", job_id, e
                     )
                     job.status = JobStatus.failed
-                    self._job_metadata[job_id][
-                        "error"
-                    ] = f"Failed to process results: {str(e)}"
+                    self._job_metadata[job_id]["error"] = (
+                        f"Failed to process results: {str(e)}"
+                    )
             else:
                 logger.warning(
                     "No results files found for job %s in directory %s",
@@ -189,9 +190,7 @@ async def run_eval(
             logger.error("stdout: %s", stdout.decode("utf-8") if stdout else "")
             logger.error("stderr: %s", stderr.decode("utf-8") if stderr else "")
             job.status = JobStatus.failed
-            self._job_metadata[job_id][
-                "error"
-            ] = f"""
+            self._job_metadata[job_id]["error"] = f"""
             Process failed with return code {process.returncode}
             """
         except Exception as e:
@@ -209,6 +208,10 @@ async def run_eval(
     async def _upload_file(
         self, file: Path, purpose: OpenAIFilePurpose
     ) -> OpenAIFileObject | None:
+        if self.files_api is None:
+            logger.warning("Files API not available, cannot upload file %s", file)
+            return None
+
         if file.exists():
             with open(file, "rb") as f:
                 upload_file = await self.files_api.openai_upload_file(
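
With this guard, _upload_file can return None for a reason other than a missing file, so callers must keep treating the result as optional; the run_eval hunk above already does so via `if upload_job_result:`. A hedged caller-side sketch (the helper name record_upload is mine, not from the diff):

    import logging
    from pathlib import Path

    logger = logging.getLogger(__name__)

    async def record_upload(provider, metadata: dict[str, str], path: Path, purpose) -> None:
        # _upload_file returns None if the Files API is unavailable
        # or the file does not exist; only record an ID on success.
        result = await provider._upload_file(path, purpose)
        if result is not None:
            metadata["uploaded_file"] = result.id
        else:
            logger.warning("No upload recorded for %s", path)
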
@@ -239,7 +242,7 @@ async def _parse_job_results_from_data(
             if isinstance(task_results, dict):
                 # Extract metric scores
                 for metric_name, metric_value in task_results.items():
-                    if isinstance(metric_value, (int, float)):
+                    if isinstance(metric_value, int | float):
                         score_key = f"{task_name}:{metric_name}"
                         scores[score_key] = ScoringResult(
                             aggregated_results={
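
The new check uses PEP 604 union syntax, which isinstance() accepts since Python 3.10; it is behaviorally identical to the tuple form, so this hunk is a style modernization rather than a fix. For example:

    value = 0.91
    assert isinstance(value, (int, float))     # classic tuple form
    assert isinstance(value, int | float)      # PEP 604 form, Python 3.10+

    # Caveat in both forms: bool subclasses int, so True/False also pass.
    assert isinstance(True, int | float)
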
@@ -278,7 +281,6 @@ async def _parse_job_results_from_data(
         return EvaluateResponse(generations=[], scores={})
 
     def _create_model_args(self, base_url: str, benchmark_config: BenchmarkConfig):
-
         model_args = {"model": None, "base_url": base_url}
 
         model_name = None
@@ -333,6 +335,29 @@ def _collect_lmeval_args(
 
         return lmeval_args
 
+    def _extract_task_name(self, benchmark_id: str) -> str:
+        """Extract task name from benchmark ID.
+
+        Args:
+            benchmark_id: The benchmark identifier
+
+        Returns:
+            Task name
+
+        Raises:
+            LMEvalTaskNameError: If task name is empty or invalid
+        """
+        from ..errors import LMEvalTaskNameError
+
+        task_name_parts = benchmark_id.split("::")
+        task_name = task_name_parts[-1].strip() if task_name_parts else ""
+        if not task_name:
+            raise LMEvalTaskNameError(
+                f"Invalid benchmark_id '{benchmark_id}': task name is empty or invalid"
+            )
+
+        return task_name
+
     def build_command(
         self,
         task_config: BenchmarkConfig,
@@ -410,10 +435,7 @@ def build_command(
         cmd.extend(["--model_args", ",".join(model_args_list)])
 
         # Extract task name from benchmark_id (remove provider prefix)
-        # benchmark_id format: "inline::trustyai_lmeval::task_name"
-        task_name = (
-            benchmark_id.split("::")[-1] if "::" in benchmark_id else benchmark_id
-        )
+        task_name = self._extract_task_name(benchmark_id)
         cmd.extend(["--tasks", task_name])
 
         cmd.extend(["--limit", limit])
@@ -475,9 +497,9 @@ async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
             return
 
         if job.status in [JobStatus.in_progress, JobStatus.scheduled]:
-            process_id = self._job_metadata.get(job_id, {}).get("process_id")
-            if process_id:
-                process_id = int(process_id)
+            process_id_str = self._job_metadata.get(job_id, {}).get("process_id")
+            if process_id_str:
+                process_id = int(process_id_str)
                 logger.info("Attempting to cancel subprocess %s", process_id)
 
                 try:
@@ -573,7 +595,7 @@ async def shutdown(self) -> None:
         self.benchmarks.clear()
 
         # Close files API connection if it exists and has cleanup methods
-        if self.files_api and hasattr(self.files_api, 'close'):
+        if self.files_api and hasattr(self.files_api, "close"):
             try:
                 await self.files_api.close()
                 logger.debug("Closed Files API connection")

src/llama_stack_provider_lmeval/remote/provider.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
     remote_provider_spec,
 )
 
+
 def get_provider_spec() -> ProviderSpec:
     return remote_provider_spec(
         api=Api.eval,
