|
4 | 4 | from continuous_eval.metrics.base import Field, Metric |
5 | 5 |
|
6 | 6 |
|
| 7 | +def _count_matches(ground_truth, tools, order_sensitive=False): |
| 8 | + if order_sensitive: |
| 9 | + # For order-sensitive matching |
| 10 | + matches = 0 |
| 11 | + gt_index = 0 |
| 12 | + |
| 13 | + for tool in tools: |
| 14 | + if gt_index < len(ground_truth) and ground_truth[gt_index] == tool: |
| 15 | + matches += 1 |
| 16 | + gt_index += 1 |
| 17 | + return matches |
| 18 | + else: |
| 19 | + # For order-insensitive matching, convert dictionaries to hashable tuples |
| 20 | + def make_hashable(obj): |
| 21 | + if isinstance(obj, dict): |
| 22 | + return tuple( |
| 23 | + sorted((k, make_hashable(v)) for k, v in obj.items()) |
| 24 | + ) |
| 25 | + elif isinstance(obj, list): |
| 26 | + return tuple(make_hashable(v) for v in obj) |
| 27 | + else: |
| 28 | + return obj |
| 29 | + |
| 30 | + ground_truth_set = set(make_hashable(d) for d in ground_truth) |
| 31 | + tools_set = set(make_hashable(d) for d in tools) |
| 32 | + intersection = ground_truth_set & tools_set |
| 33 | + return len(intersection) |
| 34 | + |
| 35 | + |
7 | 36 | class ToolSelectionAccuracy(Metric): |
8 | 37 | """ |
9 | 38 | Computes the accuracy of tool selection. |
10 | 39 | """ |
11 | 40 |
|
def __init__(
    self,
    order_sensitive: bool = False,
    ignore_kwargs: bool = False,
) -> None:
    """Configure how tool calls are compared.

    Args:
        order_sensitive: Forwarded to ``_count_matches`` by ``compute`` to
            select in-order matching instead of order-insensitive matching.
        ignore_kwargs: When true, ``compute`` compares tool calls by their
            ``"name"`` only.
    """
    super().__init__(is_cpu_bound=True)
    self._ignore_kwargs = ignore_kwargs
    self._order_sensitive = order_sensitive
15 | 49 |
|
def compute(
    self, tools: List[ToolCall], ground_truths: List[ToolCall], **kwargs
):
    """Score the selected tool calls against the expected ones.

    Args:
        tools: Tool calls that were made.
        ground_truths: Tool calls that were expected.
        **kwargs: Accepted but not used here.

    Returns:
        A dict with ``"num_correct"`` (the count from ``_count_matches``)
        and ``"score"``: ``num_correct / len(ground_truths)`` when there
        are ground truths, ``0.0`` when there are none but tools were
        called, and ``1.0`` when both are empty.
    """
    expected, predicted = ground_truths, tools
    if self._ignore_kwargs:
        # Reduce every call to its name so arguments cannot affect matching.
        expected = [{"name": call["name"]} for call in ground_truths]
        predicted = [{"name": call["name"]} for call in tools]

    num_correct = _count_matches(
        expected, predicted, order_sensitive=self._order_sensitive
    )

    if len(ground_truths) > 0:
        score = num_correct / len(ground_truths)
    elif len(tools) > 0:
        # Nothing was expected, yet tools were called.
        score = 0.0
    else:
        # Nothing expected and nothing called.
        score = 1.0

    return {
        "num_correct": num_correct,
        "score": score,
    }
63 | 71 |
|
64 | 72 | @property |
|
0 commit comments