
Commit 898fa5a

Resolve Task Metrics Immediately (UKGovernmentBEIS#2439)
* Resolve Task Metrics Immediately: rather than pass task metrics through to results computation, resolve the task metrics immediately onto the scorers and simply use those when computing results.
* Add tests
* Update CHANGELOG.md
* Correct changelog

---------

Co-authored-by: jjallaire <jj.allaire@gmail.com>
1 parent 85ef118 commit 898fa5a
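
To make the change concrete, here is a minimal sketch (illustrative only, not part of the commit) of the new behavior, using only imports and names that appear in the diff and tests below. Task-level metrics are now written into each scorer's registry metadata when the Task is constructed, rather than being threaded through to eval_results() at scoring time.

    # minimal sketch, assuming the APIs shown in this commit's diff and tests
    from inspect_ai._eval.task.task import Task
    from inspect_ai._util.registry import registry_info
    from inspect_ai.scorer._classification import f1
    from inspect_ai.scorer._metrics import mean

    # task-level metrics are resolved onto the scorer at construction time...
    task = Task(scorer=f1(), metrics=[mean()])

    # ...so the scorer's registry metadata carries them, and results
    # computation can read them directly from the scorer
    scorer_metrics = registry_info(task.scorer[0]).metadata["metrics"]
    assert registry_info(scorer_metrics[0]).name == "inspect_ai/mean"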

File tree

6 files changed, +108 -16 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 - Sandbox Service: Specify instance externally so a single script can service multiple instances.
 - Agent Bridge: Capture message history in agent state for all bridge generations.
 - Agent Bridge: Embed sandbox service client in sandbox bridge proxy (for ease of bundling).
+- Scoring: Resolve task or eval level metrics onto scorers immediately rather than waiting until scoring.
 - Inspect View: Add support for cmd + arrow up/down to navigate the samples list.
 - Inspect View: Improve scroll keyboard handling in sample transcript view.
 - Inspect View: Improve scroll keyboard handling in sample messages view.

src/inspect_ai/_eval/score.py

Lines changed: 4 additions & 1 deletion
@@ -15,6 +15,7 @@
 from inspect_ai._display import display as display_manager
 from inspect_ai._eval.context import init_task_context
 from inspect_ai._eval.loader import scorer_from_spec
+from inspect_ai._eval.task.task import resolve_scorer_metrics
 from inspect_ai._util._async import configured_async_backend, run_coroutine, tg_collect
 from inspect_ai._util.platform import platform_init, running_in_notebook
 from inspect_ai._util.registry import registry_create, registry_unqualified_name
@@ -247,6 +248,9 @@ async def _score_sample(idx_sample: int) -> None:
     # that will be taken care of in eval_results)
     log_metrics = metrics_from_log_header(log)

+    # resolve the scorer metrics onto the scorers
+    scorers = resolve_scorer_metrics(scorers, log_metrics) or []
+
     # override epochs_reducer if specified
     epochs_reducer = create_reducers(epochs_reducer)
     if epochs_reducer:
@@ -260,7 +264,6 @@ async def _score_sample(idx_sample: int) -> None:
         list(filter(None, scores)),
         epochs_reducer,
         scorers,
-        log_metrics,
     )

     return log

src/inspect_ai/_eval/task/results.py

Lines changed: 2 additions & 7 deletions
@@ -69,7 +69,6 @@ def eval_results(
     scores: list[dict[str, SampleScore]],
     reducers: ScoreReducer | list[ScoreReducer] | None,
     scorers: list[Scorer] | None,
-    metrics: list[Metric] | dict[str, list[Metric]] | None,
 ) -> Tuple[EvalResults, list[EvalSampleReductions] | None]:
     # initialise results
     results = EvalResults(total_samples=samples, completed_samples=len(scores))
@@ -105,11 +104,9 @@ def eval_results(
         if len(reducers) == 0:
             # Compute metrics without reduction since no reducers were
             # explicitly specified
-            targets = metrics if metrics is not None else scorer_info.metrics
-
             eval_scores = compute_eval_scores(
                 resolved_scores,
-                targets,
+                scorer_info.metrics,
                 scorer_name,
                 scorer_info,
                 None,
@@ -135,11 +132,9 @@ def eval_results(
             sample_reductions.append(reduced_samples)

             # Compute metrics for this scorer
-            targets = metrics if metrics is not None else scorer_info.metrics
-
             eval_scores = compute_eval_scores(
                 reduced_scores,
-                targets,
+                scorer_info.metrics,
                 scorer_name,
                 scorer_info,
                 reducer_display_nm,

src/inspect_ai/_eval/task/run.py

Lines changed: 1 addition & 7 deletions
@@ -77,7 +77,7 @@
 )
 from inspect_ai.model._model import init_sample_model_usage, sample_model_usage
 from inspect_ai.scorer import Scorer, Target
-from inspect_ai.scorer._metric import Metric, SampleScore
+from inspect_ai.scorer._metric import SampleScore
 from inspect_ai.scorer._reducer.types import ScoreReducer
 from inspect_ai.scorer._score import init_scoring_context
 from inspect_ai.scorer._scorer import unique_scorer_name
@@ -311,7 +311,6 @@ def sample_complete(sample_score: dict[str, SampleScore]) -> None:
             progress_results,
             scorers,
             task.epochs_reducer,
-            task.metrics,
         )

     # initial progress
@@ -323,7 +322,6 @@ def sample_complete(sample_score: dict[str, SampleScore]) -> None:
             progress_results,
             scorers,
             task.epochs_reducer,
-            task.metrics,
         )

     async def run_sample(
@@ -383,7 +381,6 @@ async def run_sample(
             scores=completed_scores,
             reducers=task.epochs_reducer,
             scorers=scorers,
-            metrics=task.metrics,
         )

     # collect eval data
@@ -477,7 +474,6 @@ def update_metrics_display_fn(
         list[dict[str, SampleScore]],
         list[Scorer] | None,
         ScoreReducer | list[ScoreReducer] | None,
-        list[Metric] | dict[str, list[Metric]] | None,
     ],
     None,
 ]:
@@ -488,7 +484,6 @@ def compute(
         sample_scores: list[dict[str, SampleScore]],
         scorers: list[Scorer] | None,
         reducers: ScoreReducer | list[ScoreReducer] | None,
-        metrics: list[Metric] | dict[str, list[Metric]] | None,
     ) -> None:
         # Don't compute metrics if they are not being displayed
         if not display_metrics:
@@ -503,7 +498,6 @@ def compute(
             scores=sample_scores,
             reducers=reducers,
             scorers=scorers,
-            metrics=metrics,
         )

         # Name, reducer, value

src/inspect_ai/_eval/task/task.py

Lines changed: 17 additions & 1 deletion
@@ -8,9 +8,11 @@
 from inspect_ai._util.logger import warn_once
 from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
 from inspect_ai._util.registry import (
+    RegistryInfo,
     is_registry_object,
     registry_info,
     registry_unqualified_name,
+    set_registry_info,
 )
 from inspect_ai.agent._agent import Agent, is_agent
 from inspect_ai.agent._as_solver import as_solver
@@ -141,7 +143,7 @@ def __init__(
         self.setup = setup
         self.solver = resolve_solver(solver)
         self.cleanup = cleanup
-        self.scorer = resolve_scorer(scorer)
+        self.scorer = resolve_scorer_metrics(resolve_scorer(scorer), metrics)
         self.metrics = metrics
         self.model = resolve_model(model)
         self.config = config
@@ -417,3 +419,17 @@ def resolve_scorer(scorer: Scorer | list[Scorer] | None) -> list[Scorer] | None:
     return (
         scorer if isinstance(scorer, list) else [scorer] if scorer is not None else None
     )
+
+
+def resolve_scorer_metrics(
+    scorers: list[Scorer] | None, metrics: list[Metric] | dict[str, list[Metric]] | None
+) -> list[Scorer] | None:
+    if scorers is not None and metrics is not None:
+        for scorer in scorers:
+            scorer_info = registry_info(scorer)
+            new_metadata = {**scorer_info.metadata, "metrics": metrics}
+            new_info = RegistryInfo(
+                type=scorer_info.type, name=scorer_info.name, metadata=new_metadata
+            )
+            set_registry_info(scorer, new_info)
+    return scorers
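
A brief usage sketch of the new resolve_scorer_metrics helper (illustrative only, not part of the commit): when either argument is None it returns the scorers unchanged, and otherwise it rewrites each scorer's registry info so that its "metrics" metadata holds the task-level metrics.

    # sketch of calling the helper added above; names taken from this diff and the tests
    from inspect_ai._eval.task.task import resolve_scorer_metrics
    from inspect_ai._util.registry import registry_info
    from inspect_ai.scorer._classification import f1
    from inspect_ai.scorer._metrics import accuracy

    # no-op when there are no task-level metrics (or no scorers)
    assert resolve_scorer_metrics(None, [accuracy()]) is None

    # otherwise the metrics are copied into each scorer's registry metadata
    scorers = resolve_scorer_metrics([f1()], [accuracy()]) or []
    task_metrics = registry_info(scorers[0]).metadata["metrics"]
    assert registry_info(task_metrics[0]).name == "inspect_ai/accuracy"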

tests/scorer/test_task_scorer.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+from inspect_ai import eval
+from inspect_ai._eval.task.task import Task
+from inspect_ai._util.registry import registry_info
+from inspect_ai.dataset._sources.csv import csv_dataset
+from inspect_ai.scorer._answer import answer
+from inspect_ai.scorer._classification import f1
+from inspect_ai.scorer._metrics import accuracy, mean
+from inspect_ai.scorer._metrics.std import bootstrap_stderr
+
+
+def test_task_with_metrics():
+    task = Task(scorer=f1(), metrics=[mean(), bootstrap_stderr()])
+
+    # ensure that metrics themselves remain unchanged
+    assert registry_info(task.metrics[0]).name == "inspect_ai/mean"
+    assert registry_info(task.metrics[1]).name == "inspect_ai/bootstrap_stderr"
+    assert task.scorer is not None
+
+    # ensure that the task metrics are correctly applied to the scorer
+    info = registry_info(task.scorer[0])
+    assert registry_info(info.metadata["metrics"][0]).name == "inspect_ai/mean"
+
+    info = registry_info(task.scorer[0])
+    assert (
+        registry_info(info.metadata["metrics"][1]).name == "inspect_ai/bootstrap_stderr"
+    )
+
+    # modify the task and ensure that the new metrics remain unchanged
+    task.scorer.append(answer("word"))
+    assert len(task.scorer) == 2
+    info = registry_info(task.scorer[1])
+    assert registry_info(info.metadata["metrics"][0]).name == "inspect_ai/accuracy"
+    assert registry_info(info.metadata["metrics"][1]).name == "inspect_ai/stderr"
+
+
+def test_task_score_results():
+    task = Task(
+        dataset=csv_dataset("tests/dataset/test_dataset/samples-md.csv"),
+        scorer=f1(),
+        metrics=[accuracy()],
+    )
+
+    # confirm that only the task-level accuracy metric is computed
+    log = eval(task, model="mockllm/model", sandbox=False)
+    assert len(log[0].results.scores) == 1
+    assert len(log[0].results.scores[0].metrics) == 1
+    assert "accuracy" in log[0].results.scores[0].metrics
+    assert "mean" not in log[0].results.scores[0].metrics
+    assert "stderr" not in log[0].results.scores[0].metrics
+
+
+def test_score_results():
+    task = Task(
+        dataset=csv_dataset("tests/dataset/test_dataset/samples-md.csv"),
+        scorer=f1(),
+    )
+
+    # confirm the mean result is computed
+    log = eval(task, model="mockllm/model")
+    assert len(log[0].results.scores) == 1
+    assert len(log[0].results.scores[0].metrics) == 2
+    assert "mean" in log[0].results.scores[0].metrics
+    assert "stderr" in log[0].results.scores[0].metrics
+
+
+def test_added_scores():
+    task = Task(
+        dataset=csv_dataset("tests/dataset/test_dataset/samples-md.csv"),
+        scorer=f1(),
+        metrics=[accuracy()],
+    )
+    task.scorer.append(answer("line"))
+
+    log = eval(task, model="mockllm/model")
+    assert len(log[0].results.scores) == 2
+    assert len(log[0].results.scores[0].metrics) == 1
+    assert "accuracy" in log[0].results.scores[0].metrics
+    assert "mean" not in log[0].results.scores[0].metrics
+    assert "stderr" not in log[0].results.scores[0].metrics
+
+    assert len(log[0].results.scores[1].metrics) == 2
+    assert "accuracy" in log[0].results.scores[1].metrics
+    assert "stderr" in log[0].results.scores[1].metrics
