Skip to content

Commit 61512df

Browse files
authored
Merge pull request #1235 from vespa-engine/boeker/handle-internal-ids-recall
Handle internal IDs in recall computation
2 parents e3afd7f + 20bbe85 commit 61512df

File tree

2 files changed

+120
-3
lines changed

2 files changed

+120
-3
lines changed

tests/unit/test_evaluator.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3311,6 +3311,81 @@ def test_compute_recall(self):
33113311
delta=0.0001,
33123312
)
33133313

3314+
def test_compute_recall_id_field(self):
    """Check recall for hits that carry their id both as hit id and in ``fields["id"]``.

    A five-hit response compared with itself must give recall 1.0; compared
    with a response holding only the first four of those hits it must give 0.8.
    """

    def hits_with_id_field(count):
        # Hits "1".."count"; every hit repeats its id inside fields["id"].
        return [
            {"id": str(number), "fields": {"id": str(number)}}
            for number in range(1, count + 1)
        ]

    exact = self.SuccessfullMockVespaResponse(hits_with_id_field(5))
    recall_against_itself = self.recall_evaluator._compute_recall(exact, exact)
    self.assertAlmostEqual(recall_against_itself, 1.0, delta=0.0001)

    # The approximate response misses hit "5", i.e. 4 of 5 exact hits are found.
    approx = self.SuccessfullMockVespaResponse(hits_with_id_field(4))
    recall_against_approx = self.recall_evaluator._compute_recall(exact, approx)
    self.assertAlmostEqual(recall_against_approx, 0.8, delta=0.0001)
3343+
3344+
class InternalIDResponse(MockVespaResponse):
    """Mock response that rewrites plain hit ids into ``index:cluster/<n>/<id>`` form.

    ``<n>`` is a running counter that starts at ``first_node_id`` and grows by
    one for every id that gets rewritten.
    """

    def __init__(
        self,
        hits,
        first_node_id=0,
        _total_count=None,
        _timing=None,
        _status_code=200,
    ):
        """Forward the usual mock arguments to the base class and seed the counter."""
        super().__init__(hits, _total_count, _timing, _status_code)
        # Counter used as the middle path segment of the next rewritten id.
        self.next_node_num = first_node_id

    def add_namespace_to_hit_ids(self, hits_list) -> List[Dict[str, Any]]:
        """Prefix every plain string id in ``hits_list`` and return the hits as a new list.

        The hit dicts themselves are modified in place. Ids that already start
        with ``index:`` (and ids that are not strings) are left untouched, so
        running this twice over the same hits does not prefix them twice.
        """
        rewritten = []
        for hit in hits_list:
            current_id = hit.get("id")
            needs_prefix = isinstance(current_id, str) and not current_id.startswith(
                "index:"
            )
            if needs_prefix:
                hit["id"] = f"index:cluster/{self.next_node_num}/{current_id}"
                self.next_node_num += 1
            rewritten.append(hit)
        return rewritten

    def is_successful(self):
        """Always report success."""
        return True
3368+
3369+
def test_compute_recall_internal_ids(self):
    """Check recall for responses built with ``InternalIDResponse``.

    The exact response holds hits "1".."5" (counter starting at 0), the
    approximate one holds hits "1".."4" (counter starting at 1). Expected
    recall is 1.0 for exact-vs-exact and 0.8 for exact-vs-approximate.
    """
    exact_hits = [{"id": str(number)} for number in range(1, 6)]
    exact = self.InternalIDResponse(exact_hits, first_node_id=0)
    recall_against_itself = self.recall_evaluator._compute_recall(exact, exact)
    self.assertAlmostEqual(recall_against_itself, 1.0, delta=0.0001)

    # Different starting counter than the exact response; one hit fewer.
    approx_hits = [{"id": str(number)} for number in range(1, 5)]
    approx = self.InternalIDResponse(approx_hits, first_node_id=1)
    recall_against_approx = self.recall_evaluator._compute_recall(exact, approx)
    self.assertAlmostEqual(recall_against_approx, 0.8, delta=0.0001)
3388+
33143389
def test_run(self):
33153390
class MockVespaApp:
33163391
def __init__(self, first_mock_responses, second_mock_responses):

vespa/evaluation/_base.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,7 @@ class VespaNNRecallEvaluator:
18651865
hits (int): Number of hits to use. Should match the parameter targetHits in the used ANN queries.
18661866
app (Vespa): An instance of the Vespa application.
18671867
query_limit (int): Maximum number of queries to determine the recall for. Defaults to 20.
1868+
id_field (str): Name of the field containing a unique id. Defaults to "id".
18681869
**kwargs (dict, optional): Additional HTTP request parameters. See: <https://docs.vespa.ai/en/reference/document-v1-api-reference.html#request-parameters>.
18691870
"""
18701871

@@ -1874,12 +1875,14 @@ def __init__(
18741875
hits: int,
18751876
app: Vespa,
18761877
query_limit: int = 20,
1878+
id_field: str = "id",
18771879
**kwargs,
18781880
):
18791881
self.queries = queries
18801882
self.hits = hits
18811883
self.app = app
18821884
self.query_limit = query_limit
1885+
self.id_field = id_field
18831886
self.parameters = kwargs
18841887

18851888
def _compute_recall(
@@ -1904,8 +1907,39 @@ def _compute_recall(
19041907
except KeyError:
19051908
results_approx = []
19061909

1907-
ids_exact = list(map(lambda x: x["id"], results_exact))
1908-
ids_approx = list(map(lambda x: x["id"], results_approx))
1910+
def extract_id(hit: dict, id_field: str) -> Tuple[str, str]:
    """Extract a document id from a Vespa hit.

    The sources are tried in this order:

    1. ``hit["fields"][id_field]``               -> kind ``"id_field"``
    2. ``hit["id"]`` containing ``"::"``         -> kind ``"document_id"``
    3. ``hit["id"]`` starting with ``"index:"``  -> kind ``"internal_id"``;
       the id has the form ``index:[source]/[node-index]/[hex-gid]`` and
       only the ``hex-gid`` part is returned.

    Args:
        hit: A single hit dictionary.
        id_field: Name of the field expected to hold a unique id.

    Returns:
        A ``(id, kind)`` tuple where ``kind`` is one of ``"id_field"``,
        ``"document_id"`` or ``"internal_id"``.

    Raises:
        ValueError: If none of the sources yields an id. This includes a
            missing or non-string ``hit["id"]`` and an internal id without a
            non-empty ``hex-gid`` part.
    """
    # "fields" may be absent or explicitly None; treat both as "no fields".
    fields = hit.get("fields") or {}
    if id_field in fields:
        return fields[id_field], "id_field"

    hit_id = hit.get("id")
    if isinstance(hit_id, str):
        # Document id.
        if "::" in hit_id:
            return hit_id, "document_id"

        # Internal id: index:[source]/[node-index]/[hex-gid]
        if hit_id.startswith("index:"):
            parts = hit_id.split("/", 2)
            # Require all three segments and a non-empty gid; anything else
            # falls through to the ValueError below instead of an IndexError.
            if len(parts) == 3 and parts[2]:
                return parts[2], "internal_id"

    raise ValueError(f"Could not extract a document id from hit: {hit}")
1930+
1931+
ids_exact = list(map(lambda x: extract_id(x, self.id_field)[0], results_exact))
1932+
ids_approx = list(
1933+
map(lambda x: extract_id(x, self.id_field)[0], results_approx)
1934+
)
1935+
1936+
id_types = set()
1937+
id_types.update(map(lambda x: extract_id(x, self.id_field)[1], results_exact))
1938+
id_types.update(map(lambda x: extract_id(x, self.id_field)[1], results_approx))
1939+
if len(id_types) > 1:
1940+
print(
1941+
f"Warning: Multiple id types obtained for hits: {id_types}. The recall computation will not be reliable. Please specify id_field correctly."
1942+
)
19091943

19101944
if len(ids_exact) != self.hits:
19111945
print(
@@ -2125,6 +2159,7 @@ class VespaNNParameterOptimizer:
21252159
benchmark_time_limit (int): Time in milliseconds to spend per bucket benchmark. Defaults to 5000.
21262160
recall_query_limit(int): Number of queries per bucket to compute the recall for. Defaults to 20.
21272161
max_concurrent(int): Number of queries to execute concurrently during benchmark/recall calculation. Defaults to 10.
2162+
id_field (str): Name of the field containing a unique id for recall computation. Defaults to "id".
21282163
"""
21292164

21302165
def __init__(
@@ -2137,6 +2172,7 @@ def __init__(
21372172
benchmark_time_limit: int = 5000,
21382173
recall_query_limit: int = 20,
21392174
max_concurrent: int = 10,
2175+
id_field: str = "id",
21402176
):
21412177
self.app = app
21422178
self.queries = queries
@@ -2150,6 +2186,7 @@ def __init__(
21502186
self.benchmark_time_limit = benchmark_time_limit
21512187
self.recall_query_limit = recall_query_limit
21522188
self.max_concurrent = max_concurrent
2189+
self.id_field = id_field
21532190

21542191
def get_bucket_interval_width(self) -> float:
21552192
"""
@@ -2512,7 +2549,12 @@ def compute_average_recalls(self, **kwargs) -> BucketedMetricResults:
25122549
end="",
25132550
)
25142551
recall_evaluator = VespaNNRecallEvaluator(
2515-
bucket, self.hits, self.app, self.recall_query_limit, **kwargs
2552+
bucket,
2553+
self.hits,
2554+
self.app,
2555+
self.recall_query_limit,
2556+
self.id_field,
2557+
**kwargs,
25162558
)
25172559
recall_list = recall_evaluator.run()
25182560
results.append(recall_list)

0 commit comments

Comments
 (0)