+from inspect_ai import eval
+from inspect_ai._eval.task.task import Task
+from inspect_ai._util.registry import registry_info
+from inspect_ai.dataset._sources.csv import csv_dataset
+from inspect_ai.scorer._answer import answer
+from inspect_ai.scorer._classification import f1
+from inspect_ai.scorer._metrics import accuracy, mean
+from inspect_ai.scorer._metrics.std import bootstrap_stderr
+
+
+def test_task_with_metrics():
+    task = Task(scorer=f1(), metrics=[mean(), bootstrap_stderr()])
+
+    # ensure that the metrics themselves remain unchanged
+    assert registry_info(task.metrics[0]).name == "inspect_ai/mean"
+    assert registry_info(task.metrics[1]).name == "inspect_ai/bootstrap_stderr"
+    assert task.scorer is not None
+
+    # ensure that the task metrics are correctly applied to the scorer
+    info = registry_info(task.scorer[0])
+    assert registry_info(info.metadata["metrics"][0]).name == "inspect_ai/mean"
+    assert (
+        registry_info(info.metadata["metrics"][1]).name == "inspect_ai/bootstrap_stderr"
+    )
+
+    # append a scorer and ensure its own default metrics remain unchanged
+    task.scorer.append(answer("word"))
+    assert len(task.scorer) == 2
+    info = registry_info(task.scorer[1])
+    assert registry_info(info.metadata["metrics"][0]).name == "inspect_ai/accuracy"
+    assert registry_info(info.metadata["metrics"][1]).name == "inspect_ai/stderr"
+
+
+def test_task_score_results():
+    task = Task(
+        dataset=csv_dataset("tests/dataset/test_dataset/samples-md.csv"),
+        scorer=f1(),
+        metrics=[accuracy()],
+    )
+
+    # confirm that the task-level accuracy metric replaces the scorer's defaults
+    log = eval(task, model="mockllm/model", sandbox=False)
+    assert len(log[0].results.scores) == 1
+    assert len(log[0].results.scores[0].metrics) == 1
+    assert "accuracy" in log[0].results.scores[0].metrics
+    assert "mean" not in log[0].results.scores[0].metrics
+    assert "stderr" not in log[0].results.scores[0].metrics
+
+
+def test_score_results():
+    task = Task(
+        dataset=csv_dataset("tests/dataset/test_dataset/samples-md.csv"),
+        scorer=f1(),
+    )
+
+    # confirm the scorer's default metrics (mean, stderr) are computed
+    log = eval(task, model="mockllm/model")
+    assert len(log[0].results.scores) == 1
+    assert len(log[0].results.scores[0].metrics) == 2
+    assert "mean" in log[0].results.scores[0].metrics
+    assert "stderr" in log[0].results.scores[0].metrics
+
+
+def test_added_scores():
+    task = Task(
+        dataset=csv_dataset("tests/dataset/test_dataset/samples-md.csv"),
+        scorer=f1(),
+        metrics=[accuracy()],
+    )
+    task.scorer.append(answer("line"))
+
+    log = eval(task, model="mockllm/model")
+    assert len(log[0].results.scores) == 2
+
+    # the original f1 scorer reports only the task-level accuracy metric
+    assert len(log[0].results.scores[0].metrics) == 1
+    assert "accuracy" in log[0].results.scores[0].metrics
+    assert "mean" not in log[0].results.scores[0].metrics
+    assert "stderr" not in log[0].results.scores[0].metrics
+
+    # the appended answer scorer reports accuracy and stderr
+    assert len(log[0].results.scores[1].metrics) == 2
+    assert "accuracy" in log[0].results.scores[1].metrics
+    assert "stderr" in log[0].results.scores[1].metrics