Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions project/paperbench/paperbench/nano/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import json
import os
import time
from collections.abc import AsyncGenerator, Sequence
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Literal, Sequence
from typing import Any, Literal

import blobfile as bf
import numpy as np
import structlog.stdlib
from dotenv import load_dotenv

Expand Down Expand Up @@ -48,8 +48,13 @@
get_timestamp,
is_docker_running,
purple,
safe_mean,
)

MIN_UPLOAD_INTERVAL_MESSAGES = 5
MIN_UPLOAD_INTERVAL_SECONDS = 1800


GRADER_OPENAI_API_KEY = os.getenv("GRADER_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")


Expand All @@ -64,7 +69,7 @@ class ExternalPythonCodingSolver(PythonCodingSolver):
default=False, doc="Whether to make the local NVIDIA GPU available to the agent"
)

upload_interval_messages: int = chz.field(
upload_interval_messages: int | None = chz.field(
default=None,
doc="Upload interval in agent steps for heavy logs",
)
Expand Down Expand Up @@ -123,16 +128,23 @@ async def _start_computer(self, task: PBTask) -> AsyncGenerator[ComputerInterfac
destinations=["run"],
)

if self.upload_interval_messages and self.upload_interval_messages <= 5:
if (
self.upload_interval_messages
and self.upload_interval_messages <= MIN_UPLOAD_INTERVAL_MESSAGES
):
ctx_logger.warning(
"Uploading artifacts every five messages or less is untested. "
"Consider setting `upload_interval_messages` to a higher value.",
destinations=["run"],
)

if self.upload_interval_seconds and self.upload_interval_seconds < 1800:
if (
self.upload_interval_seconds
and self.upload_interval_seconds < MIN_UPLOAD_INTERVAL_SECONDS
):
ctx_logger.warning(
"Uploading artifacts more frequently than every 1800 seconds is untested. "
"Uploading artifacts more frequently than every"
f" {MIN_UPLOAD_INTERVAL_SECONDS} seconds is untested. "
"Consider setting `upload_interval_seconds` to a higher value.",
destinations=["run"],
)
Expand Down Expand Up @@ -257,7 +269,7 @@ async def _run_agent(self, computer: ComputerInterface, task: PBTask) -> AgentOu
)

ctx_logger.info(f"status: {status}", destinations=["run"])
start_time = status.get("created_at") if status.get("created_at") else time.time()
start_time = status.get("created_at", time.time())
if status.get("agent_finished_at"):
end_time = status.get("agent_finished_at")
elif status.get("last_updated"):
Expand Down Expand Up @@ -572,7 +584,7 @@ async def get_full_summary(
}

other_stats = {
"repro_mean_time": np.mean(
"repro_mean_time": safe_mean(
[
r.reproduction_output.metadata.repro_execution_time # type: ignore
for r in results_clean
Expand Down Expand Up @@ -636,16 +648,7 @@ async def get_existing_run_ids(self, run_group_id: str) -> set[str]:
return run_ids

def uses_local_config(self) -> bool:
"""
Check if any of paperbench.solver.cluster_config, paperbench.reproduction.cluster_config,
or paperbench.judge.cluster_config is an instance of LocalConfig.

Args:
paperbench: A PaperBench PythonCodingEval instance

Returns:
bool: True if any of the cluster configs is a LocalConfig, False otherwise
"""
"""Return True if any cluster config uses LocalConfig."""

# PythonCodingSolver may not have a cluster_config, just ExternalPythonCodingSolver does for now
if hasattr(self.solver, "cluster_config"):
Expand Down
22 changes: 21 additions & 1 deletion project/paperbench/paperbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import time
import uuid
from pathlib import Path
from typing import Any, Awaitable, Callable, ParamSpec, TypeVar
from typing import Any, Awaitable, Callable, ParamSpec, Sequence, TypeVar

import blobfile as bf
import numpy as np
import openai
import structlog.stdlib
import tenacity
Expand Down Expand Up @@ -95,6 +96,25 @@ def get_timestamp() -> str:
return time.strftime("%Y-%m-%dT%H-%M-%S-%Z", time.gmtime())


def safe_mean(values: Sequence[float | int], default: float = np.nan) -> float:
"""Return the mean or a default when no values are provided."""

assert isinstance(values, Sequence), "`values` must be a sequence"
assert all(isinstance(v, (int, float)) for v in values), "`values` must be numeric list"
assert isinstance(default, (int, float)), "`default` must be numeric"

if not values:
return float(default)

result = float(np.mean(values))

assert isinstance(result, (int, float)), (
f"Expected the mean to be a number, but got `{type(result)}`"
)

return result


def get_commit_hash() -> str:
"""Returns the current Git commit hash."""

Expand Down
49 changes: 49 additions & 0 deletions project/paperbench/tests/unit/test_eval_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import math

import pytest

from paperbench.utils import safe_mean


def test_safe_mean_with_values() -> None:
    """The mean of two numbers is halfway between them."""
    result = safe_mean([1.0, 3.0])

    assert math.isclose(result, 2.0)


def test_safe_mean_empty_default() -> None:
    """An empty input falls back to the caller-supplied default."""
    fallback = 5.0

    assert safe_mean([], default=fallback) == fallback


def test_safe_mean_with_nan() -> None:
    """A NaN entry propagates into the mean rather than being skipped."""
    result = safe_mean([float("nan"), 1.0])

    assert math.isnan(result)


def test_safe_mean_invalid_values() -> None:
    """Non-numeric entries are rejected with an AssertionError."""
    bad_input = [1.0, "bad"]

    with pytest.raises(AssertionError):
        safe_mean(bad_input)  # type: ignore[arg-type]