Skip to content

Commit 30544c0

Browse files
committed
Add query index params
These parameters to search() let you control the query-time parameters for index operations. They allow you to adjust the speed/recall tradeoff when querying. If these parameters aren't specified the default for the index will be used.
1 parent 443961b commit 30544c0

File tree

3 files changed

+173
-39
lines changed

3 files changed

+173
-39
lines changed

nbs/00_vector.ipynb

Lines changed: 108 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 97,
15+
"execution_count": null,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 98,
24+
"execution_count": null,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -31,7 +31,7 @@
3131
},
3232
{
3333
"cell_type": "code",
34-
"execution_count": 99,
34+
"execution_count": null,
3535
"metadata": {},
3636
"outputs": [],
3737
"source": [
@@ -42,7 +42,7 @@
4242
},
4343
{
4444
"cell_type": "code",
45-
"execution_count": 100,
45+
"execution_count": null,
4646
"metadata": {},
4747
"outputs": [],
4848
"source": [
@@ -52,7 +52,7 @@
5252
},
5353
{
5454
"cell_type": "code",
55-
"execution_count": 101,
55+
"execution_count": null,
5656
"metadata": {},
5757
"outputs": [],
5858
"source": [
@@ -73,7 +73,7 @@
7373
},
7474
{
7575
"cell_type": "code",
76-
"execution_count": 102,
76+
"execution_count": null,
7777
"metadata": {},
7878
"outputs": [],
7979
"source": [
@@ -146,7 +146,7 @@
146146
},
147147
{
148148
"cell_type": "code",
149-
"execution_count": 103,
149+
"execution_count": null,
150150
"metadata": {},
151151
"outputs": [],
152152
"source": [
@@ -265,6 +265,42 @@
265265
" .format(index_name=index_name_quoted, table_name=table_name_quoted, column_name=column_name_quoted, with_clause=with_clause)\n"
266266
]
267267
},
268+
{
269+
"attachments": {},
270+
"cell_type": "markdown",
271+
"metadata": {},
272+
"source": [
273+
"# Query Params"
274+
]
275+
},
276+
{
277+
"cell_type": "code",
278+
"execution_count": null,
279+
"metadata": {},
280+
"outputs": [],
281+
"source": [
282+
"#| export\n",
283+
"\n",
284+
"class QueryParams:\n",
285+
" def __init__(self, params: dict[str, Any]) -> None:\n",
286+
" self.params = params\n",
287+
" \n",
288+
" def get_statements(self) -> List[str]:\n",
289+
" return [\"SET LOCAL \" + key + \" = \" + str(value) for key, value in self.params.items()]\n",
290+
"\n",
291+
"class TimescaleVectorIndexParams(QueryParams):\n",
292+
" def __init__(self, search_list_size: int) -> None:\n",
293+
" super().__init__({\"tsv.query_search_list_size\": search_list_size})\n",
294+
"\n",
295+
"class IvfflatIndexParams(QueryParams):\n",
296+
" def __init__(self, probes: int) -> None:\n",
297+
" super().__init__({\"ivfflat.probes\": probes})\n",
298+
"\n",
299+
"class HNSWIndexParams(QueryParams):\n",
300+
" def __init__(self, ef_search: int) -> None:\n",
301+
" super().__init__({\"hnsw.ef_search\": ef_search})"
302+
]
303+
},
268304
{
269305
"attachments": {},
270306
"cell_type": "markdown",
@@ -275,7 +311,7 @@
275311
},
276312
{
277313
"cell_type": "code",
278-
"execution_count": 104,
314+
"execution_count": null,
279315
"metadata": {},
280316
"outputs": [],
281317
"source": [
@@ -290,7 +326,7 @@
290326
},
291327
{
292328
"cell_type": "code",
293-
"execution_count": 105,
329+
"execution_count": null,
294330
"metadata": {},
295331
"outputs": [],
296332
"source": [
@@ -388,7 +424,7 @@
388424
},
389425
{
390426
"cell_type": "code",
391-
"execution_count": 106,
427+
"execution_count": null,
392428
"metadata": {},
393429
"outputs": [],
394430
"source": [
@@ -534,7 +570,7 @@
534570
},
535571
{
536572
"cell_type": "code",
537-
"execution_count": 107,
573+
"execution_count": null,
538574
"metadata": {},
539575
"outputs": [],
540576
"source": [
@@ -848,7 +884,7 @@
848884
},
849885
{
850886
"cell_type": "code",
851-
"execution_count": 108,
887+
"execution_count": null,
852888
"metadata": {},
853889
"outputs": [
854890
{
@@ -876,7 +912,7 @@
876912
"Generates a query to create the tables, indexes, and extensions needed to store the vector data."
877913
]
878914
},
879-
"execution_count": 108,
915+
"execution_count": null,
880916
"metadata": {},
881917
"output_type": "execute_result"
882918
}
@@ -895,7 +931,7 @@
895931
},
896932
{
897933
"cell_type": "code",
898-
"execution_count": 109,
934+
"execution_count": null,
899935
"metadata": {},
900936
"outputs": [],
901937
"source": [
@@ -1128,6 +1164,7 @@
11281164
" filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n",
11291165
" predicates: Optional[Predicates] = None,\n",
11301166
" uuid_time_filter: Optional[UUIDTimeRange] = None,\n",
1167+
" query_params: Optional[QueryParams] = None\n",
11311168
" ): \n",
11321169
" \"\"\"\n",
11331170
" Retrieves similar records using a similarity query.\n",
@@ -1149,13 +1186,22 @@
11491186
" \"\"\"\n",
11501187
" (query, params) = self.builder.search_query(\n",
11511188
" query_embedding, limit, filter, predicates, uuid_time_filter)\n",
1152-
" async with await self.connect() as pool:\n",
1153-
" return await pool.fetch(query, *params)"
1189+
" if query_params is not None:\n",
1190+
" async with await self.connect() as pool:\n",
1191+
" async with pool.transaction():\n",
1192+
" #Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588\n",
1193+
" statements = query_params.get_statements()\n",
1194+
" for statement in statements:\n",
1195+
" await pool.execute(statement)\n",
1196+
" return await pool.fetch(query, *params)\n",
1197+
" else:\n",
1198+
" async with await self.connect() as pool:\n",
1199+
" return await pool.fetch(query, *params)"
11541200
]
11551201
},
11561202
{
11571203
"cell_type": "code",
1158-
"execution_count": 110,
1204+
"execution_count": null,
11591205
"metadata": {},
11601206
"outputs": [
11611207
{
@@ -1183,7 +1229,7 @@
11831229
"Creates necessary tables."
11841230
]
11851231
},
1186-
"execution_count": 110,
1232+
"execution_count": null,
11871233
"metadata": {},
11881234
"output_type": "execute_result"
11891235
}
@@ -1194,7 +1240,7 @@
11941240
},
11951241
{
11961242
"cell_type": "code",
1197-
"execution_count": 111,
1243+
"execution_count": null,
11981244
"metadata": {},
11991245
"outputs": [
12001246
{
@@ -1222,7 +1268,7 @@
12221268
"Creates necessary tables."
12231269
]
12241270
},
1225-
"execution_count": 111,
1271+
"execution_count": null,
12261272
"metadata": {},
12271273
"output_type": "execute_result"
12281274
}
@@ -1233,9 +1279,21 @@
12331279
},
12341280
{
12351281
"cell_type": "code",
1236-
"execution_count": 112,
1282+
"execution_count": null,
12371283
"metadata": {},
12381284
"outputs": [
1285+
{
1286+
"name": "stderr",
1287+
"output_type": "stream",
1288+
"text": [
1289+
"/Users/cevian/.pyenv/versions/3.11.4/envs/nbdev_env/lib/python3.11/site-packages/fastcore/docscrape.py:225: UserWarning: potentially wrong underline length... \n",
1290+
"Returns \n",
1291+
"-------- in \n",
1292+
"Retrieves similar records using a similarity query.\n",
1293+
"...\n",
1294+
" else: warn(msg)\n"
1295+
]
1296+
},
12391297
{
12401298
"data": {
12411299
"text/markdown": [
@@ -1248,7 +1306,8 @@
12481306
"> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
12491307
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
12501308
"> ne, predicates:Optional[__main__.Predicates]=None,\n",
1251-
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n",
1309+
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n",
1310+
"> query_params:Optional[__main__.QueryParams]=None)\n",
12521311
"\n",
12531312
"Retrieves similar records using a similarity query.\n",
12541313
"\n",
@@ -1259,6 +1318,7 @@
12591318
"| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n",
12601319
"| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n",
12611320
"| uuid_time_filter | Optional | None | |\n",
1321+
"| query_params | Optional | None | |\n",
12621322
"| **Returns** | **List: List of similar records.** | | |"
12631323
],
12641324
"text/plain": [
@@ -1271,7 +1331,8 @@
12711331
"> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
12721332
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
12731333
"> ne, predicates:Optional[__main__.Predicates]=None,\n",
1274-
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n",
1334+
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n",
1335+
"> query_params:Optional[__main__.QueryParams]=None)\n",
12751336
"\n",
12761337
"Retrieves similar records using a similarity query.\n",
12771338
"\n",
@@ -1282,10 +1343,11 @@
12821343
"| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n",
12831344
"| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n",
12841345
"| uuid_time_filter | Optional | None | |\n",
1346+
"| query_params | Optional | None | |\n",
12851347
"| **Returns** | **List: List of similar records.** | | |"
12861348
]
12871349
},
1288-
"execution_count": 112,
1350+
"execution_count": null,
12891351
"metadata": {},
12901352
"output_type": "execute_result"
12911353
}
@@ -1296,7 +1358,7 @@
12961358
},
12971359
{
12981360
"cell_type": "code",
1299-
"execution_count": 117,
1361+
"execution_count": null,
13001362
"metadata": {},
13011363
"outputs": [],
13021364
"source": [
@@ -1317,7 +1379,7 @@
13171379
},
13181380
{
13191381
"cell_type": "code",
1320-
"execution_count": 118,
1382+
"execution_count": null,
13211383
"metadata": {},
13221384
"outputs": [],
13231385
"source": [
@@ -1564,6 +1626,10 @@
15641626
"assert len(rec) == 0\n",
15651627
"rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(end_date=specific_datetime+timedelta(seconds=1), time_delta=timedelta(days=7)))\n",
15661628
"assert len(rec) == 1\n",
1629+
"rec = await vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(10))\n",
1630+
"assert len(rec) == 2\n",
1631+
"rec = await vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(100))\n",
1632+
"assert len(rec) == 2\n",
15671633
"await vec.drop_table()\n",
15681634
"await vec.close()"
15691635
]
@@ -1883,6 +1949,7 @@
18831949
" filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n",
18841950
" predicates: Optional[Predicates] = None,\n",
18851951
" uuid_time_filter: Optional[UUIDTimeRange] = None,\n",
1952+
" query_params: Optional[QueryParams] = None,\n",
18861953
" ):\n",
18871954
" \"\"\"\n",
18881955
" Retrieves similar records using a similarity query.\n",
@@ -1910,6 +1977,11 @@
19101977
" (query, params) = self.builder.search_query(\n",
19111978
" query_embedding_np, limit, filter, predicates, uuid_time_filter)\n",
19121979
" query, params = self._translate_to_pyformat(query, params)\n",
1980+
"\n",
1981+
" if query_params is not None:\n",
1982+
" prefix = \"; \".join(query_params.get_statements())\n",
1983+
" query = f\"{prefix}; {query}\"\n",
1984+
" \n",
19131985
" with self.connect() as conn:\n",
19141986
" with conn.cursor() as cur:\n",
19151987
" cur.execute(query, params)\n",
@@ -2021,7 +2093,8 @@
20212093
"> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
20222094
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
20232095
"> e, predicates:Optional[__main__.Predicates]=None,\n",
2024-
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n",
2096+
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n",
2097+
"> query_params:Optional[__main__.QueryParams]=None)\n",
20252098
"\n",
20262099
"Retrieves similar records using a similarity query.\n",
20272100
"\n",
@@ -2032,6 +2105,7 @@
20322105
"| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n",
20332106
"| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n",
20342107
"| uuid_time_filter | Optional | None | |\n",
2108+
"| query_params | Optional | None | |\n",
20352109
"| **Returns** | **List: List of similar records.** | | |"
20362110
],
20372111
"text/plain": [
@@ -2044,7 +2118,8 @@
20442118
"> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
20452119
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
20462120
"> e, predicates:Optional[__main__.Predicates]=None,\n",
2047-
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n",
2121+
"> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n",
2122+
"> query_params:Optional[__main__.QueryParams]=None)\n",
20482123
"\n",
20492124
"Retrieves similar records using a similarity query.\n",
20502125
"\n",
@@ -2055,6 +2130,7 @@
20552130
"| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n",
20562131
"| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n",
20572132
"| uuid_time_filter | Optional | None | |\n",
2133+
"| query_params | Optional | None | |\n",
20582134
"| **Returns** | **List: List of similar records.** | | |"
20592135
]
20602136
},
@@ -2314,6 +2390,10 @@
23142390
"assert len(rec) == 0\n",
23152391
"rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(end_date=specific_datetime+timedelta(seconds=1), time_delta=timedelta(days=7)))\n",
23162392
"assert len(rec) == 1\n",
2393+
"rec = vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(10))\n",
2394+
"assert len(rec) == 2\n",
2395+
"rec = vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(100))\n",
2396+
"assert len(rec) == 2\n",
23172397
"vec.drop_table()\n",
23182398
"vec.close()"
23192399
]

0 commit comments

Comments
 (0)