Skip to content

Commit 216d44a

Browse files
Minor refactor and handle NaNs
1 parent c7980e0 commit 216d44a

3 files changed

Lines changed: 91 additions & 19 deletions

File tree

project/paperbench/paperbench/nano/eval.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import json
33
import os
44
import time
5+
from collections.abc import AsyncGenerator, Sequence
56
from contextlib import asynccontextmanager
6-
from typing import Any, AsyncGenerator, Literal, Sequence
7+
from typing import Any, Literal
78

89
import blobfile as bf
9-
import numpy as np
1010
import structlog.stdlib
1111
from dotenv import load_dotenv
1212

@@ -48,8 +48,13 @@
4848
get_timestamp,
4949
is_docker_running,
5050
purple,
51+
safe_mean,
5152
)
5253

54+
MIN_UPLOAD_INTERVAL_MESSAGES = 5
55+
MIN_UPLOAD_INTERVAL_SECONDS = 1800
56+
57+
5358
GRADER_OPENAI_API_KEY = os.getenv("GRADER_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
5459

5560

@@ -64,7 +69,7 @@ class ExternalPythonCodingSolver(PythonCodingSolver):
6469
default=False, doc="Whether to make the local NVIDIA GPU available to the agent"
6570
)
6671

67-
upload_interval_messages: int = chz.field(
72+
upload_interval_messages: int | None = chz.field(
6873
default=None,
6974
doc="Upload interval in agent steps for heavy logs",
7075
)
@@ -123,16 +128,23 @@ async def _start_computer(self, task: PBTask) -> AsyncGenerator[ComputerInterfac
123128
destinations=["run"],
124129
)
125130

126-
if self.upload_interval_messages and self.upload_interval_messages <= 5:
131+
if (
132+
self.upload_interval_messages
133+
and self.upload_interval_messages <= MIN_UPLOAD_INTERVAL_MESSAGES
134+
):
127135
ctx_logger.warning(
128136
"Uploading artifacts every five messages or less is untested. "
129137
"Consider setting `upload_interval_messages` to a higher value.",
130138
destinations=["run"],
131139
)
132140

133-
if self.upload_interval_seconds and self.upload_interval_seconds < 1800:
141+
if (
142+
self.upload_interval_seconds
143+
and self.upload_interval_seconds < MIN_UPLOAD_INTERVAL_SECONDS
144+
):
134145
ctx_logger.warning(
135-
"Uploading artifacts more frequently than every 1800 seconds is untested. "
146+
"Uploading artifacts more frequently than every"
147+
f" {MIN_UPLOAD_INTERVAL_SECONDS} seconds is untested. "
136148
"Consider setting `upload_interval_seconds` to a higher value.",
137149
destinations=["run"],
138150
)
@@ -257,7 +269,7 @@ async def _run_agent(self, computer: ComputerInterface, task: PBTask) -> AgentOu
257269
)
258270

259271
ctx_logger.info(f"status: {status}", destinations=["run"])
260-
start_time = status.get("created_at") if status.get("created_at") else time.time()
272+
start_time = status.get("created_at", time.time())
261273
if status.get("agent_finished_at"):
262274
end_time = status.get("agent_finished_at")
263275
elif status.get("last_updated"):
@@ -572,7 +584,7 @@ async def get_full_summary(
572584
}
573585

574586
other_stats = {
575-
"repro_mean_time": np.mean(
587+
"repro_mean_time": safe_mean(
576588
[
577589
r.reproduction_output.metadata.repro_execution_time # type: ignore
578590
for r in results_clean
@@ -636,16 +648,7 @@ async def get_existing_run_ids(self, run_group_id: str) -> set[str]:
636648
return run_ids
637649

638650
def uses_local_config(self) -> bool:
639-
"""
640-
Check if any of paperbench.solver.cluster_config, paperbench.reproduction.cluster_config,
641-
or paperbench.judge.cluster_config is an instance of LocalConfig.
642-
643-
Args:
644-
paperbench: A PaperBench PythonCodingEval instance
645-
646-
Returns:
647-
bool: True if any of the cluster configs is a LocalConfig, False otherwise
648-
"""
651+
"""Return True if any cluster config uses LocalConfig."""
649652

650653
# PythonCodingSolver may not have a cluster_config, just ExternalPythonCodingSolver does for now
651654
if hasattr(self.solver, "cluster_config"):

project/paperbench/paperbench/utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import time
77
import uuid
88
from pathlib import Path
9-
from typing import Any, Awaitable, Callable, ParamSpec, TypeVar
9+
from typing import Any, Awaitable, Callable, ParamSpec, Sequence, TypeVar
1010

1111
import blobfile as bf
12+
import numpy as np
1213
import openai
1314
import structlog.stdlib
1415
import tenacity
@@ -95,6 +96,25 @@ def get_timestamp() -> str:
9596
return time.strftime("%Y-%m-%dT%H-%M-%S-%Z", time.gmtime())
9697

9798

99+
def safe_mean(values: Sequence[float | int], default: float = np.nan) -> float:
100+
"""Return the mean or a default when no values are provided."""
101+
102+
assert isinstance(values, Sequence), "`values` must be a sequence"
103+
assert all(isinstance(v, (int, float)) for v in values), "`values` must be numeric list"
104+
assert isinstance(default, (int, float)), "`default` must be numeric"
105+
106+
if not values:
107+
return float(default)
108+
109+
result = float(np.mean(values))
110+
111+
assert isinstance(result, (int, float)), (
112+
f"Expected the mean to be a number, but got `{type(result)}`"
113+
)
114+
115+
return result
116+
117+
98118
def get_commit_hash() -> str:
99119
"""Returns the current Git commit hash."""
100120

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import math
2+
3+
import pytest
4+
5+
from paperbench.utils import safe_mean
6+
7+
8+
def test_safe_mean_with_values() -> None:
    # Given a pair of floats whose average is known by hand
    sample = [1.0, 3.0]

    # When / Then: the computed mean matches the hand-computed average
    assert math.isclose(safe_mean(sample), 2.0)
18+
19+
20+
def test_safe_mean_empty_default() -> None:
    # Given an explicit fallback and no values to average
    fallback = 5.0
    empty: list[float] = []

    # When the input is empty
    result = safe_mean(empty, default=fallback)

    # Then the fallback is returned unchanged
    assert result == fallback
30+
31+
32+
def test_safe_mean_with_nan() -> None:
    # Given a sample containing a NaN entry
    sample = [float("nan"), 1.0]

    # When / Then: NaN propagates through the mean
    assert math.isnan(safe_mean(sample))
41+
42+
43+
def test_safe_mean_invalid_values() -> None:
    # Given a sample with a non-numeric entry
    sample = [1.0, "bad"]

    # When / Then: input validation rejects the mixed-type sequence
    with pytest.raises(AssertionError):
        safe_mean(sample)  # type: ignore[arg-type]

0 commit comments

Comments
 (0)