Skip to content

Commit ac8aeb7

Browse files
committed
Fix bug in ToolSelectionAccuracy
1 parent 28168b7 commit ac8aeb7

File tree

1 file changed

+48
-40
lines changed

1 file changed

+48
-40
lines changed

continuous_eval/metrics/tools/match.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,61 +4,69 @@
44
from continuous_eval.metrics.base import Field, Metric
55

66

7+
def _count_matches(ground_truth, tools, order_sensitive=False):
8+
if order_sensitive:
9+
# For order-sensitive matching
10+
matches = 0
11+
gt_index = 0
12+
13+
for tool in tools:
14+
if gt_index < len(ground_truth) and ground_truth[gt_index] == tool:
15+
matches += 1
16+
gt_index += 1
17+
return matches
18+
else:
19+
# For order-insensitive matching, convert dictionaries to hashable tuples
20+
def make_hashable(obj):
21+
if isinstance(obj, dict):
22+
return tuple(
23+
sorted((k, make_hashable(v)) for k, v in obj.items())
24+
)
25+
elif isinstance(obj, list):
26+
return tuple(make_hashable(v) for v in obj)
27+
else:
28+
return obj
29+
30+
ground_truth_set = set(make_hashable(d) for d in ground_truth)
31+
tools_set = set(make_hashable(d) for d in tools)
32+
intersection = ground_truth_set & tools_set
33+
return len(intersection)
34+
35+
736
class ToolSelectionAccuracy(Metric):
837
"""
938
Computes the accuracy of tool selection.
1039
"""
1140

12-
def __init__(self, order_sensitive: bool = False) -> None:
41+
def __init__(
42+
self,
43+
order_sensitive: bool = False,
44+
ignore_kwargs: bool = False,
45+
) -> None:
1346
super().__init__(is_cpu_bound=True)
1447
self._order_sensitive = order_sensitive
48+
self._ignore_kwargs = ignore_kwargs
1549

1650
def compute(
1751
self, tools: List[ToolCall], ground_truths: List[ToolCall], **kwargs
1852
):
19-
if self._order_sensitive:
20-
# When order matters, compare tool executions directly in sequence.
21-
num_correct = sum(
22-
1
23-
for i, tool in enumerate(tools)
24-
if i < len(ground_truths)
25-
and tool["name"] == ground_truths[i]["name"]
26-
and tool["kwargs"] == ground_truths[i]["kwargs"]
27-
)
53+
if self._ignore_kwargs:
54+
_ground_truths = [{"name": t["name"]} for t in ground_truths]
55+
_tools = [{"name": t["name"]} for t in tools]
2856
else:
29-
# Convert ground_truth to a format that's easy to check for "contains"
30-
use_kwargs = all("kwargs" in tool for tool in ground_truths)
31-
if use_kwargs:
32-
ground_truth_set = {
33-
frozenset(tool.items())
34-
for tool in [
35-
{"name": tool["name"], **tool["kwargs"]}
36-
for tool in ground_truths
37-
]
38-
}
39-
else:
40-
ground_truth_set = {
41-
frozenset(tool.items()) for tool in ground_truths
42-
}
43-
# Score
44-
num_correct, matched_executions = 0, set()
45-
for tool in tools:
46-
if use_kwargs:
47-
tool_set = frozenset(
48-
{"name": tool["name"], **tool["kwargs"]}.items()
49-
)
50-
else:
51-
tool_set = frozenset({"name": tool["name"]}.items())
52-
if (
53-
tool_set in ground_truth_set
54-
and tool_set not in matched_executions
55-
):
56-
num_correct += 1
57-
matched_executions.add(tool_set)
57+
_ground_truths, _tools = ground_truths, tools
58+
num_correct = _count_matches(
59+
_ground_truths, _tools, order_sensitive=self._order_sensitive
60+
)
61+
score = 1.0
62+
if len(ground_truths) > 0:
63+
score = num_correct / len(ground_truths)
64+
elif len(tools) > 0:
65+
score = 0.0
5866

5967
return {
6068
"num_correct": num_correct,
61-
"score": num_correct / len(ground_truths),
69+
"score": score,
6270
}
6371

6472
@property

0 commit comments

Comments
 (0)