Skip to content

Commit 29f5a0b

Browse files
committed
add dependecies; add model name to output paths
1 parent 08b339a commit 29f5a0b

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

eval/inspection_ai/scicode.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,10 @@ def save_prompt_with_steps(
119119
prompt: str,
120120
num_steps: int
121121
) -> None:
122-
output_dir = Path(self.prompt_dir, self._get_background_dir())
122+
output_dir = Path(
123+
self.prompt_dir,
124+
self._get_background_dir()
125+
)
123126
output_dir.mkdir(parents=True, exist_ok=True)
124127
output_file_path = output_dir / f"{prob_data['problem_id']}.{num_steps}.txt"
125128
output_file_path.write_text(prompt, encoding="utf-8")
@@ -185,8 +188,8 @@ class ScicodeEvaluator:
185188
def __init__(
186189
self,
187190
h5py_file: str,
188-
code_dir: str,
189-
log_dir: str,
191+
code_dir: Path,
192+
log_dir: Path,
190193
with_background: bool,
191194
):
192195
self.h5py_file = h5py_file
@@ -306,9 +309,10 @@ def generate_gold_response(prob_data: dict, num_steps: int):
306309
@solver
307310
def scicode_solver(**params: dict[str, Any]):
308311
async def solve(state: TaskState, generate: Generate) -> TaskState:
312+
model_name = str(state.model).replace("/", "-")
309313
prompt_assistant = ScicodePromptingAssistant(
310-
output_dir=Path(params["output_dir"], "generated_code"),
311-
prompt_dir=Path(params["output_dir"], "prompt"),
314+
output_dir=Path(params["output_dir"], model_name, "generated_code"),
315+
prompt_dir=Path(params["output_dir"], model_name, "prompt"),
312316
with_background=params["with_background"],
313317
)
314318
prompt_template = BACKGOUND_PROMPT_TEMPLATE if params["with_background"] else DEFAULT_PROMPT_TEMPLATE
@@ -365,10 +369,11 @@ def metric(scores: list[Score]) -> int | float:
365369
)
366370
def scicode_scorer(**params: dict[str, Any]):
367371
async def score(state: TaskState, target: Target):
372+
model_name = str(state.model).replace("/", "-")
368373
evaluator = ScicodeEvaluator(
369374
h5py_file=params["h5py_file"],
370-
code_dir=params["output_dir"],
371-
log_dir=params["output_dir"],
375+
code_dir=Path(params["output_dir"], model_name),
376+
log_dir=Path(params["output_dir"], model_name),
372377
with_background=params["with_background"],
373378
)
374379
problem_correct, total_correct, total_steps = evaluator.test_code(state.metadata)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
"scipy",
3838
"matplotlib",
3939
"sympy",
40+
"inspect-ai",
4041
]
4142

4243
# Classifiers help users find your project by categorizing it.

0 commit comments

Comments
 (0)