@@ -1865,6 +1865,7 @@ class VespaNNRecallEvaluator:
18651865 hits (int): Number of hits to use. Should match the parameter targetHits in the used ANN queries.
18661866 app (Vespa): An instance of the Vespa application.
18671867 query_limit (int): Maximum number of queries to determine the recall for. Defaults to 20.
1868+ id_field (str): Name of the field containing a unique id. Defaults to "id".
18681869 **kwargs (dict, optional): Additional HTTP request parameters. See: <https://docs.vespa.ai/en/reference/document-v1-api-reference.html#request-parameters>.
18691870 """
18701871
@@ -1874,12 +1875,14 @@ def __init__(
18741875 hits : int ,
18751876 app : Vespa ,
18761877 query_limit : int = 20 ,
1878+ id_field : str = "id" ,
18771879 ** kwargs ,
18781880 ):
18791881 self .queries = queries
18801882 self .hits = hits
18811883 self .app = app
18821884 self .query_limit = query_limit
1885+ self .id_field = id_field
18831886 self .parameters = kwargs
18841887
18851888 def _compute_recall (
@@ -1904,8 +1907,39 @@ def _compute_recall(
19041907 except KeyError :
19051908 results_approx = []
19061909
1907- ids_exact = list (map (lambda x : x ["id" ], results_exact ))
1908- ids_approx = list (map (lambda x : x ["id" ], results_approx ))
1910+ def extract_id (hit : dict , id_field : str ) -> Tuple [str , str ]:
1911+ """Extract document ID from a Vespa hit."""
1912+
1913+ # id as specified by field
1914+ fields = hit .get ("fields" , {})
1915+ if id_field in fields :
1916+ return fields [id_field ], "id_field"
1917+
1918+ # document id
1919+ id = hit .get ("id" , "" )
1920+ if "::" in id :
1921+ return id , "document_id"
1922+
1923+ # internal id
1924+ if id .startswith (
1925+ "index:"
1926+ ): # id is an internal id of the form index:[source]/[node-index]/[hex-gid], return hex-gid
1927+ return id .split ("/" , 2 )[2 ], "internal_id"
1928+
1929+ raise ValueError (f"Could not extract a document id from hit: { hit } " )
1930+
1931+ ids_exact = list (map (lambda x : extract_id (x , self .id_field )[0 ], results_exact ))
1932+ ids_approx = list (
1933+ map (lambda x : extract_id (x , self .id_field )[0 ], results_approx )
1934+ )
1935+
1936+ id_types = set ()
1937+ id_types .update (map (lambda x : extract_id (x , self .id_field )[1 ], results_exact ))
1938+ id_types .update (map (lambda x : extract_id (x , self .id_field )[1 ], results_approx ))
1939+ if len (id_types ) > 1 :
1940+ print (
1941+ f"Warning: Multiple id types obtained for hits: { id_types } . The recall computation will not be reliable. Please specify id_field correctly."
1942+ )
19091943
19101944 if len (ids_exact ) != self .hits :
19111945 print (
@@ -2125,6 +2159,7 @@ class VespaNNParameterOptimizer:
21252159 benchmark_time_limit (int): Time in milliseconds to spend per bucket benchmark. Defaults to 5000.
21262160 recall_query_limit(int): Number of queries per bucket to compute the recall for. Defaults to 20.
21272161 max_concurrent(int): Number of queries to execute concurrently during benchmark/recall calculation. Defaults to 10.
2162+ id_field (str): Name of the field containing a unique id for recall computation. Defaults to "id".
21282163 """
21292164
21302165 def __init__ (
@@ -2137,6 +2172,7 @@ def __init__(
21372172 benchmark_time_limit : int = 5000 ,
21382173 recall_query_limit : int = 20 ,
21392174 max_concurrent : int = 10 ,
2175+ id_field : str = "id" ,
21402176 ):
21412177 self .app = app
21422178 self .queries = queries
@@ -2150,6 +2186,7 @@ def __init__(
21502186 self .benchmark_time_limit = benchmark_time_limit
21512187 self .recall_query_limit = recall_query_limit
21522188 self .max_concurrent = max_concurrent
2189+ self .id_field = id_field
21532190
21542191 def get_bucket_interval_width (self ) -> float :
21552192 """
@@ -2512,7 +2549,12 @@ def compute_average_recalls(self, **kwargs) -> BucketedMetricResults:
25122549 end = "" ,
25132550 )
25142551 recall_evaluator = VespaNNRecallEvaluator (
2515- bucket , self .hits , self .app , self .recall_query_limit , ** kwargs
2552+ bucket ,
2553+ self .hits ,
2554+ self .app ,
2555+ self .recall_query_limit ,
2556+ self .id_field ,
2557+ ** kwargs ,
25162558 )
25172559 recall_list = recall_evaluator .run ()
25182560 results .append (recall_list )
0 commit comments