Skip to content

Commit efbf04e

Browse files
Clean up the reproduce flow
1 parent 109fc4c commit efbf04e

8 files changed

Lines changed: 200 additions & 226 deletions

File tree

project/paperbench/paperbench/agents/registry.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,19 @@ def get_agent(self, agent_id: str) -> Agent:
108108

109109

110110
registry = AgentRegistry()
111+
112+
113+
def get_agents_env_vars(registry: AgentRegistry) -> dict[str, str]:
    """Parse the registry's ``agent.env`` file of KEY=VALUE lines into a dict.

    Blank lines and comment lines (starting with ``#``, even when indented)
    are ignored, as are malformed lines with no ``=`` separator. The value
    keeps everything after the first ``=``, so values may themselves
    contain ``=``.

    Returns an empty dict (with a warning) if ``agent.env`` does not exist.
    """
    agent_env_path = registry.get_agents_dir() / "agent.env"

    if not agent_env_path.exists():
        logger.warning(f"agent.env not found in {agent_env_path}")
        return {}

    env_vars: dict[str, str] = {}
    with open(agent_env_path, "r") as f:
        for raw_line in f:
            # Strip BEFORE the comment check so indented comments are skipped too
            # (the naive startswith("#") on the raw line would miss them).
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            # Skip malformed lines with no separator instead of raising
            # ValueError on the 2-tuple unpack.
            if "=" not in line:
                logger.warning(f"Skipping malformed line in {agent_env_path}: {line!r}")
                continue
            key, value = line.split("=", 1)  # maxsplit=1: value may contain '='
            env_vars[key] = value
    return env_vars

project/paperbench/paperbench/nano/eval.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ async def run(self, task: ComputerTask) -> AsyncGenerator[Step | FinalResult, No
193193
code_only=task.judge.code_only,
194194
resources_provided=task.judge.resources_provided,
195195
judge_output=None,
196-
reproduction_output=None,
196+
reproduction_metadata=None,
197197
monitor_result=grade.paperbench_result.monitor_result,
198198
monitor_ran=grade.paperbench_result.monitor_ran,
199199
),
@@ -569,11 +569,7 @@ async def get_full_summary(
569569
[r for r in results_clean if not r.agent_output or not r.submission_exists]
570570
),
571571
"n_reproductions_failed": len(
572-
[
573-
r
574-
for r in results_clean
575-
if not r.reproduction_output or not r.reproduction_output.success
576-
]
572+
[r for r in results_clean if not r.reproduction_metadata]
577573
),
578574
"n_gradings_failed": len(
579575
[r for r in results_clean if not r.judge_output or not r.judge_output.success]
@@ -595,36 +591,32 @@ async def get_full_summary(
595591
other_stats = {
596592
"repro_mean_time": safe_mean(
597593
[
598-
r.reproduction_output.metadata.repro_execution_time # type: ignore
594+
r.reproduction_metadata.repro_execution_time
599595
for r in results_clean
600-
if r.reproduction_output and r.reproduction_output.success
596+
if r.reproduction_metadata and r.reproduction_metadata.repro_execution_time
601597
]
602598
),
603599
"n_is_valid_git_repo": len(
604600
[
605601
r
606602
for r in results_clean
607-
if r.reproduction_output
608-
and r.reproduction_output.success
609-
and r.reproduction_output.metadata.is_valid_git_repo # type: ignore
603+
if r.reproduction_metadata and r.reproduction_metadata.is_valid_git_repo
610604
]
611605
),
612606
"n_nontrivial_git_log": len(
613607
[
614608
r
615609
for r in results_clean
616-
if r.reproduction_output
617-
and r.reproduction_output.success
618-
and len(r.reproduction_output.metadata.git_log.strip().splitlines()) > 1 # type: ignore
610+
if r.reproduction_metadata
611+
and r.reproduction_metadata.git_log is not None
612+
and len(r.reproduction_metadata.git_log.strip().splitlines()) > 1
619613
]
620614
),
621615
"n_repro_script_exists": len(
622616
[
623617
r
624618
for r in results_clean
625-
if r.reproduction_output
626-
and r.reproduction_output.success
627-
and r.reproduction_output.metadata.repro_script_exists # type: ignore
619+
if r.reproduction_metadata and r.reproduction_metadata.repro_script_exists
628620
]
629621
),
630622
}

project/paperbench/paperbench/nano/structs.py

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import os
44
from dataclasses import dataclass
5-
from pathlib import Path
6-
from typing import Any
5+
from typing import Any, Self
76

87
from dotenv import load_dotenv
98

@@ -15,7 +14,7 @@
1514
from nanoeval.solvers.computer_tasks.task import Grade
1615
from preparedness_turn_completer.oai_turn_completer import OpenAITurnCompleter
1716
from preparedness_turn_completer.turn_completer import TurnCompleter
18-
from pydantic import BaseModel
17+
from pydantic import BaseModel, model_validator
1918

2019
from paperbench.agents.utils import (
2120
AgentOutput,
@@ -28,38 +27,6 @@
2827
logger = structlog.stdlib.get_logger(component=__name__)
2928

3029

31-
class ReproductionOutput(BaseModel):
32-
executed_submission: Path | str | None = None
33-
metadata: ReproductionMetadata | None = None
34-
35-
@classmethod
36-
def from_dict(cls, data: dict[str, Any]) -> ReproductionOutput:
37-
metadata_exists = data.get("metadata") is not None
38-
39-
if metadata_exists:
40-
metadata = ReproductionMetadata.from_dict(data["metadata"])
41-
else:
42-
metadata = None
43-
44-
try:
45-
return cls(
46-
executed_submission=data.get("executed_submission"),
47-
metadata=metadata,
48-
)
49-
except KeyError as e:
50-
raise ValueError("Missing required field in reproduction output") from e
51-
52-
def to_dict(self) -> dict[str, Any]:
53-
return {
54-
"executed_submission": self.executed_submission,
55-
"metadata": self.metadata.to_dict() if self.metadata else None,
56-
}
57-
58-
@property
59-
def success(self) -> bool:
60-
return self.metadata is not None
61-
62-
6330
@dataclass(frozen=False)
6431
class PaperBenchResult:
6532
paper_id: str
@@ -70,7 +37,7 @@ class PaperBenchResult:
7037
resources_provided: bool
7138
agent_output: AgentOutput | None = None
7239
judge_output: JudgeOutput | None = None
73-
reproduction_output: ReproductionOutput | None = None
40+
reproduction_metadata: ReproductionMetadata | None = None
7441
monitor_result: MonitorResult | None = None
7542
monitor_ran: bool = False
7643

@@ -84,7 +51,7 @@ def to_dict(self) -> dict[str, Any]:
8451
"resources_provided": self.resources_provided,
8552
"agent_output": None,
8653
"judge_output": None,
87-
"reproduction_output": None,
54+
"reproduction_metadata": None,
8855
"monitor_result": None,
8956
"monitor_ran": self.monitor_ran,
9057
}
@@ -95,8 +62,8 @@ def to_dict(self) -> dict[str, Any]:
9562
if self.judge_output:
9663
data["judge_output"] = self.judge_output.to_dict()
9764

98-
if self.reproduction_output:
99-
data["reproduction_output"] = self.reproduction_output.to_dict()
65+
if self.reproduction_metadata:
66+
data["reproduction_metadata"] = self.reproduction_metadata.to_dict()
10067

10168
if self.monitor_result:
10269
data["monitor_result"] = self.monitor_result.to_dict()
@@ -106,6 +73,7 @@ def to_dict(self) -> dict[str, Any]:
10673

10774
class ReproductionConfig(BaseModel):
10875
timeout: int = 100 * 3600
76+
# if the reproduce.sh runs for less than this, it will be retried with salvaging fixes
10977
retry_threshold: float = 600
11078
overwrite_existing_output: bool = False
11179
skip_reproduction: bool = False
@@ -114,6 +82,14 @@ class ReproductionConfig(BaseModel):
11482
pull_from_registry=False,
11583
)
11684

85+
@model_validator(mode="after")
86+
def _validate_timeout_and_retry_threshold(self) -> Self:
87+
if self.retry_threshold >= self.timeout:
88+
logger.warning(
89+
"ReproductionConfig.retry_threshold >= ReproductionConfig.timeout, so reproduce.sh salvaging is disabled.",
90+
)
91+
return self
92+
11793

11894
class JudgeConfig(BaseModel):
11995
grade: bool = True

project/paperbench/paperbench/nano/task.py

Lines changed: 27 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import tempfile
66
import time
77
from contextlib import asynccontextmanager, nullcontext
8-
from dataclasses import asdict
8+
from dataclasses import asdict, replace
99
from datetime import timedelta
1010
from pathlib import Path
1111
from typing import Any, AsyncGenerator
@@ -42,12 +42,10 @@
4242
PaperBenchGrade,
4343
PaperBenchResult,
4444
ReproductionConfig,
45-
ReproductionOutput,
4645
)
4746
from paperbench.nano.utils import get_file_at_duration
4847
from paperbench.paper_registry import paper_registry
49-
from paperbench.scripts.alcatraz_services import reproduce_on_computer
50-
from paperbench.scripts.run_reproduce import ReproductionMetadata
48+
from paperbench.scripts.run_reproduce import ReproductionMetadata, reproduce_on_computer
5149
from paperbench.utils import purple
5250

5351
GRADER_OPENAI_API_KEY = os.getenv("GRADER_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
@@ -197,7 +195,7 @@ def _early_exit_grade(
197195
resources_provided=self.judge.resources_provided,
198196
agent_output=None,
199197
judge_output=None,
200-
reproduction_output=None,
198+
reproduction_metadata=None,
201199
monitor_result=monitor_result,
202200
monitor_ran=monitor_ran,
203201
),
@@ -208,12 +206,12 @@ def _early_exit_grade(
208206
self._save_grade(grade)
209207
return grade
210208

211-
def _should_grade(self, reproduction_output: ReproductionOutput | None) -> bool:
209+
def _should_grade(self, reproduction_metadata: ReproductionMetadata | None) -> bool:
212210
"""
213211
We can proceed with grading if reproduction was successful
214212
OR we are in a reproduction-free setup
215213
"""
216-
return (reproduction_output and reproduction_output.success) or (
214+
return (reproduction_metadata is not None) or (
217215
self.reproduction.skip_reproduction or self.judge.code_only
218216
)
219217

@@ -277,17 +275,17 @@ async def grade(
277275
)
278276

279277
# 3. run reproduction
280-
repro_output = None
278+
repro_metadata = None
281279
submission_to_grade_path = path_to_submission
282280
if self._should_reproduce():
283-
repro_output = await self._run_reproduce(path_to_submission)
284-
repro_metadata = repro_output.metadata.to_dict() if repro_output.metadata else {}
281+
repro_metadata = await self._run_reproduce(path_to_submission)
282+
repro_metadata_dict = repro_metadata.to_dict() if repro_metadata else {}
285283
submission_to_grade_path = path_to_executed_submission
286-
self._record_extra({"repro_metadata": repro_metadata})
284+
self._record_extra({"repro_metadata": repro_metadata_dict})
287285

288286
# 4. run judge
289287
judge_output = None
290-
if self._should_grade(repro_output):
288+
if self._should_grade(repro_metadata):
291289
judge_output = await self._run_judge(submission_to_grade_path, self.paper_id)
292290
self._record_extra({"judge_output": judge_output.to_dict() if judge_output else None})
293291

@@ -301,7 +299,7 @@ async def grade(
301299
code_only=self.judge.code_only,
302300
resources_provided=self.judge.resources_provided,
303301
judge_output=judge_output,
304-
reproduction_output=repro_output,
302+
reproduction_metadata=repro_metadata,
305303
monitor_ran=mon_ran,
306304
monitor_result=mon_result,
307305
),
@@ -346,24 +344,23 @@ def _run_monitor(self, log_file_path: str) -> MonitorResult:
346344
monitor_result = monitor.check_log(log_file_path)
347345
return monitor_result
348346

349-
async def _run_reproduce(self, submission: str) -> ReproductionOutput:
347+
async def _run_reproduce(self, submission: str) -> ReproductionMetadata | None:
350348
"""Runs the reproduction process for the submission associated with the PBTask."""
351349
ctx_logger = logger.bind(
352350
run_group_id=self.run_group_id,
353351
run_id=self.run_id,
354352
runs_dir=self.runs_dir,
355353
)
356354
ctx_logger.info(
357-
f"Starting the reproduction process for `{self.question_id}.{self.attempt_id}`...",
355+
f"Starting the reproduction process for `{self.run_id}`...",
358356
destinations=["group", "run"],
359357
_print=True,
360358
)
361359

360+
metadata: ReproductionMetadata | None = None
362361
reproduce_output_path = submission.replace(".tar.gz", "_executed.tar.gz")
363362
repro_metadata_path = submission.replace(".tar.gz", "_executed_metadata.json")
364363

365-
ctx_logger.info(f"Reproducing submission {reproduce_output_path}...", destinations=["run"])
366-
367364
# If the reproduction output already exists, we can skip reproduction
368365
if not self.reproduction.overwrite_existing_output:
369366
repro_output_exists = bf.exists(reproduce_output_path)
@@ -374,62 +371,37 @@ async def _run_reproduce(self, submission: str) -> ReproductionOutput:
374371
destinations=["run"],
375372
)
376373
with bf.BlobFile(repro_metadata_path, "r") as f:
377-
data = json.loads(f.read())
378-
metadata = ReproductionMetadata.from_dict(data)
379-
return ReproductionOutput(
380-
executed_submission=reproduce_output_path,
381-
metadata=metadata,
382-
)
383-
384-
# Reproduce on alcatraz
385-
async with self._start_computer(self.reproduction.cluster_config) as computer:
386-
await reproduce_on_computer(
387-
computer=computer,
374+
metadata = ReproductionMetadata.from_dict(json.loads(f.read()))
375+
metadata = replace(metadata, executed_submission=reproduce_output_path)
376+
return metadata
377+
378+
# Reproduce on alcatraz and collect metadata
379+
try:
380+
metadata = await reproduce_on_computer(
381+
cluster_config=self.reproduction.cluster_config,
388382
submission_path=submission,
389383
logger=ctx_logger.bind(destinations=["run"]),
390384
run_dir=self.run_dir,
391385
timeout=self.reproduction.timeout,
392386
retry_threshold=self.reproduction.retry_threshold,
393387
)
388+
except Exception as e:
389+
logger.exception(f"Reproduction failed with error:\n{str(e)}")
394390

395-
# Now the result should exist
396-
repro_output_exists = bf.exists(reproduce_output_path)
397-
repro_metadata_exists = bf.exists(repro_metadata_path)
398-
if not repro_output_exists:
399-
ctx_logger.exception(
400-
f"Reproduction failed to produce output: {reproduce_output_path}",
401-
destinations=["group", "run"],
402-
_print=True,
403-
)
404-
return ReproductionOutput(
405-
executed_submission=reproduce_output_path,
406-
metadata=None,
407-
)
408-
if not repro_metadata_exists:
391+
if metadata is None:
409392
ctx_logger.exception(
410393
f"Reproduction failed to produce metadata: {repro_metadata_path}",
411394
destinations=["group", "run"],
412395
_print=True,
413396
)
414-
return ReproductionOutput(
415-
executed_submission=reproduce_output_path,
416-
metadata=None,
417-
)
418-
419-
with bf.BlobFile(repro_metadata_path, "r") as f:
420-
data = json.loads(f.read())
421-
metadata = ReproductionMetadata.from_dict(data)
422397

423398
ctx_logger.info(
424-
f"The reproduction process for {self.question_id}.{self.attempt_id} has finished!",
399+
f"The reproduction process for {self.run_id} has finished!",
425400
destinations=["group", "run"],
426401
_print=True,
427402
)
428403

429-
return ReproductionOutput(
430-
executed_submission=reproduce_output_path,
431-
metadata=metadata,
432-
)
404+
return metadata
433405

434406
async def _select_checkpoint(self) -> tuple[str, timedelta] | None:
435407
"""Identifies the submission tarball to use for reproduction/grading."""

project/paperbench/paperbench/reproducer.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ ENV DEBIAN_FRONTEND=noninteractive
88
RUN apt-get update && \
99
apt-get install -y \
1010
software-properties-common \
11-
wget curl unzip \
11+
wget curl unzip sudo \
1212
build-essential git cmake \
1313
libatlas-base-dev libblas-dev liblapack-dev libopenblas-dev \
1414
gfortran libsm6 libxext6 libxrender-dev && \

0 commit comments

Comments
 (0)