diff --git a/project/paperbench/paperbench/nano/eval.py b/project/paperbench/paperbench/nano/eval.py
index e7f052b2..6057c58a 100644
--- a/project/paperbench/paperbench/nano/eval.py
+++ b/project/paperbench/paperbench/nano/eval.py
@@ -2,11 +2,11 @@
 import json
 import os
 import time
+from collections.abc import AsyncGenerator, Sequence
 from contextlib import asynccontextmanager
-from typing import Any, AsyncGenerator, Literal, Sequence
+from typing import Any, Literal
 
 import blobfile as bf
-import numpy as np
 import structlog.stdlib
 from dotenv import load_dotenv
 
@@ -48,8 +48,13 @@
     get_timestamp,
     is_docker_running,
     purple,
+    safe_mean,
 )
 
+MIN_UPLOAD_INTERVAL_MESSAGES = 5
+MIN_UPLOAD_INTERVAL_SECONDS = 1800
+
+
 GRADER_OPENAI_API_KEY = os.getenv("GRADER_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
 
 
@@ -64,7 +69,7 @@ class ExternalPythonCodingSolver(PythonCodingSolver):
         default=False, doc="Whether to make the local NVIDIA GPU available to the agent"
     )
 
-    upload_interval_messages: int = chz.field(
+    upload_interval_messages: int | None = chz.field(
         default=None,
         doc="Upload interval in agent steps for heavy logs",
     )
@@ -123,16 +128,23 @@ async def _start_computer(self, task: PBTask) -> AsyncGenerator[ComputerInterfac
             destinations=["run"],
         )
 
-        if self.upload_interval_messages and self.upload_interval_messages <= 5:
+        if (
+            self.upload_interval_messages
+            and self.upload_interval_messages <= MIN_UPLOAD_INTERVAL_MESSAGES
+        ):
             ctx_logger.warning(
                 "Uploading artifacts every five messages or less is untested. "
                 "Consider setting `upload_interval_messages` to a higher value.",
                 destinations=["run"],
             )
 
-        if self.upload_interval_seconds and self.upload_interval_seconds < 1800:
+        if (
+            self.upload_interval_seconds
+            and self.upload_interval_seconds < MIN_UPLOAD_INTERVAL_SECONDS
+        ):
             ctx_logger.warning(
-                "Uploading artifacts more frequently than every 1800 seconds is untested. "
+                "Uploading artifacts more frequently than every"
+                f" {MIN_UPLOAD_INTERVAL_SECONDS} seconds is untested. "
                 "Consider setting `upload_interval_seconds` to a higher value.",
                 destinations=["run"],
             )
@@ -257,7 +269,7 @@ async def _run_agent(self, computer: ComputerInterface, task: PBTask) -> AgentOu
         )
         ctx_logger.info(f"status: {status}", destinations=["run"])
 
-        start_time = status.get("created_at") if status.get("created_at") else time.time()
+        start_time = status.get("created_at") or time.time()
         if status.get("agent_finished_at"):
             end_time = status.get("agent_finished_at")
         elif status.get("last_updated"):
@@ -572,7 +584,7 @@ async def get_full_summary(
             )
         }
         other_stats = {
-            "repro_mean_time": np.mean(
+            "repro_mean_time": safe_mean(
                 [
                     r.reproduction_output.metadata.repro_execution_time  # type: ignore
                     for r in results_clean
@@ -636,16 +648,7 @@ async def get_existing_run_ids(self, run_group_id: str) -> set[str]:
         return run_ids
 
     def uses_local_config(self) -> bool:
-        """
-        Check if any of paperbench.solver.cluster_config, paperbench.reproduction.cluster_config,
-        or paperbench.judge.cluster_config is an instance of LocalConfig.
-
-        Args:
-            paperbench: A PaperBench PythonCodingEval instance
-
-        Returns:
-            bool: True if any of the cluster configs is a LocalConfig, False otherwise
-        """
+        """Return True if any cluster config uses LocalConfig."""
 
         # PythonCodingSolver may not have a cluster_config, just ExternalPythonCodingSolver does for now
         if hasattr(self.solver, "cluster_config"):
diff --git a/project/paperbench/paperbench/utils.py b/project/paperbench/paperbench/utils.py
index c9e13b93..cc5e1073 100644
--- a/project/paperbench/paperbench/utils.py
+++ b/project/paperbench/paperbench/utils.py
@@ -6,9 +6,10 @@
 import time
 import uuid
 from pathlib import Path
-from typing import Any, Awaitable, Callable, ParamSpec, TypeVar
+from typing import Any, Awaitable, Callable, ParamSpec, Sequence, TypeVar
 
 import blobfile as bf
+import numpy as np
 import openai
 import structlog.stdlib
 import tenacity
@@ -95,6 +96,25 @@
 def get_timestamp() -> str:
     return time.strftime("%Y-%m-%dT%H-%M-%S-%Z", time.gmtime())
 
 
+def safe_mean(values: Sequence[float | int], default: float = np.nan) -> float:
+    """Return the mean or a default when no values are provided."""
+
+    assert isinstance(values, Sequence), "`values` must be a sequence"
+    assert all(isinstance(v, (int, float)) for v in values), "`values` must be numeric list"
+    assert isinstance(default, (int, float)), "`default` must be numeric"
+
+    if not values:
+        return float(default)
+
+    result = float(np.mean(values))
+
+    assert isinstance(result, (int, float)), (
+        f"Expected the mean to be a number, but got `{type(result)}`"
+    )
+
+    return result
+
+
 def get_commit_hash() -> str:
     """Returns the current Git commit hash."""
diff --git a/project/paperbench/tests/unit/test_eval_utils.py b/project/paperbench/tests/unit/test_eval_utils.py
new file mode 100644
index 00000000..649b1d4a
--- /dev/null
+++ b/project/paperbench/tests/unit/test_eval_utils.py
@@ -0,0 +1,49 @@
+import math
+
+import pytest
+
+from paperbench.utils import safe_mean
+
+
+def test_safe_mean_with_values() -> None:
+    # Given
+    values = [1.0, 3.0]
+    expected = 2.0
+
+    # When
+    actual = safe_mean(values)
+
+    # Then
+    assert math.isclose(actual, expected)
+
+
+def test_safe_mean_empty_default() -> None:
+    # Given
+    values: list[float] = []
+    expected = 5.0
+
+    # When
+    actual = safe_mean(values, default=expected)
+
+    # Then
+    assert actual == expected
+
+
+def test_safe_mean_with_nan() -> None:
+    # Given
+    values = [float("nan"), 1.0]
+
+    # When
+    actual = safe_mean(values)
+
+    # Then
+    assert math.isnan(actual)
+
+
+def test_safe_mean_invalid_values() -> None:
+    # Given
+    values = [1.0, "bad"]
+
+    # When/Then
+    with pytest.raises(AssertionError):
+        safe_mean(values)  # type: ignore[arg-type]