Skip to content

Commit d985f22

Browse files
committed
Make the embedding optional in the search
1 parent 89a57ea commit d985f22

File tree

4 files changed

+70
-23
lines changed

4 files changed

+70
-23
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,9 @@ await vec.drop_embedding_index()
191191
Please note the community is actively working on new indexing methods
192192
for embeddings. As they become available, we will add them to our client
193193
as well.
194+
195+
## Development
196+
197+
Please note that this project is developed with
198+
[nbdev](https://nbdev.fast.ai/). Please see that website for the
199+
development process.

nbs/00_vector.ipynb

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -247,21 +247,26 @@
247247
"\n",
248248
" return (where, params) \n",
249249
"\n",
250-
" def search_query(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n",
250+
" def search_query(self, query_embedding: Optional[List[float]], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n",
251251
" \"\"\"\n",
252252
" Generates a similarity query.\n",
253253
"\n",
254254
" Args:\n",
255-
" query_embedding (List[float]): The query embedding vector.\n",
255+
" query_embedding (Optiona[List[float]], optional): The query embedding vector.\n",
256256
" k (int, optional): The number of nearest neighbors to retrieve. Default is 10.\n",
257257
" filter (Optional[dict], optional): A filter for metadata. Default is None.\n",
258258
"\n",
259259
" Returns:\n",
260260
" Tuple[str, List]: A tuple containing the query and parameters.\n",
261261
" \"\"\"\n",
262262
" params = []\n",
263-
" distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n",
264-
" params = params + [query_embedding]\n",
263+
" if query_embedding is not None:\n",
264+
" distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n",
265+
" params = params + [query_embedding]\n",
266+
" order_by_clause = \"ORDER BY {distance} ASC\".format(distance=distance)\n",
267+
" else:\n",
268+
" distance = \"-1.0\"\n",
269+
" order_by_clause = \"\"\n",
265270
"\n",
266271
" (where, params) = self._where_clause_for_filter(params, filter)\n",
267272
"\n",
@@ -272,9 +277,9 @@
272277
" {table_name}\n",
273278
" WHERE \n",
274279
" {where}\n",
275-
" ORDER BY {distance} ASC\n",
280+
" {order_by_clause}\n",
276281
" LIMIT {k}\n",
277-
" '''.format(distance=distance, where=where, table_name=self._quote_ident(self.table_name), k=k)\n",
282+
" '''.format(distance=distance, order_by_clause=order_by_clause, where=where, table_name=self._quote_ident(self.table_name), k=k)\n",
278283
" return (query, params)"
279284
]
280285
},
@@ -504,7 +509,7 @@
504509
" await pool.execute(query)\n",
505510
"\n",
506511
" async def search(self, \n",
507-
" query_embedding: List[float], # vector to search for\n",
512+
" query_embedding: Optional[List[float]] = None, # vector to search for\n",
508513
" k: int=10, # The number of nearest neighbors to retrieve. Default is 10.\n",
509514
" filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None): # A filter for metadata. Default is None.\n",
510515
" \"\"\"\n",
@@ -622,7 +627,7 @@
622627
"\n",
623628
"### Async.search\n",
624629
"\n",
625-
"> Async.search (query_embedding:List[float], k:int=10,\n",
630+
"> Async.search (query_embedding:Optional[List[float]]=None, k:int=10,\n",
626631
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
627632
"> ne)\n",
628633
"\n",
@@ -638,7 +643,7 @@
638643
"\n",
639644
"### Async.search\n",
640645
"\n",
641-
"> Async.search (query_embedding:List[float], k:int=10,\n",
646+
"> Async.search (query_embedding:Optional[List[float]]=None, k:int=10,\n",
642647
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
643648
"> ne)\n",
644649
"\n",
@@ -715,6 +720,8 @@
715720
"assert len(rec) == 10\n",
716721
"rec = await vec.search([1.0, 2.0], k=4)\n",
717722
"assert len(rec) == 4\n",
723+
"rec = await vec.search(k=4)\n",
724+
"assert len(rec) == 4\n",
718725
"rec = await vec.search([1.0, 2.0], k=4, filter={\"key2\":\"val2\"})\n",
719726
"assert len(rec) == 1\n",
720727
"rec = await vec.search([1.0, 2.0], k=4, filter={\"key2\":\"does not exist\"})\n",
@@ -725,9 +732,12 @@
725732
"assert len(rec) == 1\n",
726733
"rec = await vec.search([1.0, 2.0], k=4, filter={\"key_1\":\"val_1\", \"key_2\":\"val_3\"})\n",
727734
"assert len(rec) == 0\n",
728-
"\n",
735+
"rec = await vec.search(k=4, filter={\"key_1\":\"val_1\", \"key_2\":\"val_3\"})\n",
736+
"assert len(rec) == 0\n",
729737
"rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
730738
"assert len(rec) == 2\n",
739+
"rec = await vec.search(k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
740+
"assert len(rec) == 2\n",
731741
"\n",
732742
"rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}, {\"no such key\": \"no such val\"}])\n",
733743
"assert len(rec) == 2\n",
@@ -1010,7 +1020,7 @@
10101020
" with conn.cursor() as cur:\n",
10111021
" cur.execute(query)\n",
10121022
"\n",
1013-
" def search(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):\n",
1023+
" def search(self, query_embedding: Optional[List[float]]=None, k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):\n",
10141024
" \"\"\"\n",
10151025
" Retrieves similar records using a similarity query.\n",
10161026
"\n",
@@ -1022,7 +1032,10 @@
10221032
" Returns:\n",
10231033
" List: List of similar records.\n",
10241034
" \"\"\"\n",
1025-
" (query, params) = self.builder.search_query(np.array(query_embedding), k, filter)\n",
1035+
" if query_embedding is not None:\n",
1036+
" query_embedding = np.array(query_embedding)\n",
1037+
" \n",
1038+
" (query, params) = self.builder.search_query(query_embedding, k, filter)\n",
10261039
" query, params = self._translate_to_pyformat(query, params)\n",
10271040
" with self.connect() as conn:\n",
10281041
" with conn.cursor() as cur:\n",
@@ -1140,7 +1153,7 @@
11401153
"\n",
11411154
"### Sync.search\n",
11421155
"\n",
1143-
"> Sync.search (query_embedding:List[float], k:int=10,\n",
1156+
"> Sync.search (query_embedding:Optional[List[float]]=None, k:int=10,\n",
11441157
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
11451158
"> e)\n",
11461159
"\n",
@@ -1161,7 +1174,7 @@
11611174
"\n",
11621175
"### Sync.search\n",
11631176
"\n",
1164-
"> Sync.search (query_embedding:List[float], k:int=10,\n",
1177+
"> Sync.search (query_embedding:Optional[List[float]]=None, k:int=10,\n",
11651178
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
11661179
"> e)\n",
11671180
"\n",
@@ -1246,10 +1259,14 @@
12461259
"assert len(rec) == 10\n",
12471260
"rec = vec.search([1.0, 2.0], k=4)\n",
12481261
"assert len(rec) == 4\n",
1262+
"rec = vec.search(k=4)\n",
1263+
"assert len(rec) == 4\n",
12491264
"rec = vec.search([1.0, 2.0], k=4, filter={\"key2\":\"val2\"})\n",
12501265
"assert len(rec) == 1\n",
12511266
"rec = vec.search([1.0, 2.0], k=4, filter={\"key2\":\"does not exist\"})\n",
12521267
"assert len(rec) == 0\n",
1268+
"rec = vec.search(k=4, filter={\"key2\":\"does not exist\"})\n",
1269+
"assert len(rec) == 0\n",
12531270
"rec = vec.search([1.0, 2.0], k=4, filter={\"key_1\":\"val_1\"})\n",
12541271
"assert len(rec) == 1\n",
12551272
"rec = vec.search([1.0, 2.0], filter={\"key_1\":\"val_1\", \"key_2\":\"val_2\"})\n",

nbs/index.ipynb

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,22 @@
541541
"metadata": {},
542542
"outputs": [],
543543
"source": []
544+
},
545+
{
546+
"attachments": {},
547+
"cell_type": "markdown",
548+
"metadata": {},
549+
"source": [
550+
"## Development\n",
551+
"\n",
552+
"Please note that this project is developed with [nbdev](https://nbdev.fast.ai/). Please see that website for the development process."
553+
]
554+
},
555+
{
556+
"attachments": {},
557+
"cell_type": "markdown",
558+
"metadata": {},
559+
"source": []
544560
}
545561
],
546562
"metadata": {

timescale_vector/client.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -172,21 +172,26 @@ def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str
172172

173173
return (where, params)
174174

175-
def search_query(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:
175+
def search_query(self, query_embedding: Optional[List[float]], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:
176176
"""
177177
Generates a similarity query.
178178
179179
Args:
180-
query_embedding (List[float]): The query embedding vector.
180+
query_embedding (Optiona[List[float]], optional): The query embedding vector.
181181
k (int, optional): The number of nearest neighbors to retrieve. Default is 10.
182182
filter (Optional[dict], optional): A filter for metadata. Default is None.
183183
184184
Returns:
185185
Tuple[str, List]: A tuple containing the query and parameters.
186186
"""
187187
params = []
188-
distance = "embedding {op} ${index}".format(op=self.distance_type, index=len(params)+1)
189-
params = params + [query_embedding]
188+
if query_embedding is not None:
189+
distance = "embedding {op} ${index}".format(op=self.distance_type, index=len(params)+1)
190+
params = params + [query_embedding]
191+
order_by_clause = "ORDER BY {distance} ASC".format(distance=distance)
192+
else:
193+
distance = "-1.0"
194+
order_by_clause = ""
190195

191196
(where, params) = self._where_clause_for_filter(params, filter)
192197

@@ -197,9 +202,9 @@ def search_query(self, query_embedding: List[float], k: int=10, filter: Optional
197202
{table_name}
198203
WHERE
199204
{where}
200-
ORDER BY {distance} ASC
205+
{order_by_clause}
201206
LIMIT {k}
202-
'''.format(distance=distance, where=where, table_name=self._quote_ident(self.table_name), k=k)
207+
'''.format(distance=distance, order_by_clause=order_by_clause, where=where, table_name=self._quote_ident(self.table_name), k=k)
203208
return (query, params)
204209

205210
# %% ../nbs/00_vector.ipynb 11
@@ -369,7 +374,7 @@ async def create_ivfflat_index(self, num_records=None):
369374
await pool.execute(query)
370375

371376
async def search(self,
372-
query_embedding: List[float], # vector to search for
377+
query_embedding: Optional[List[float]] = None, # vector to search for
373378
k: int=10, # The number of nearest neighbors to retrieve. Default is 10.
374379
filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None): # A filter for metadata. Default is None.
375380
"""
@@ -596,7 +601,7 @@ def create_ivfflat_index(self, num_records=None):
596601
with conn.cursor() as cur:
597602
cur.execute(query)
598603

599-
def search(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):
604+
def search(self, query_embedding: Optional[List[float]]=None, k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):
600605
"""
601606
Retrieves similar records using a similarity query.
602607
@@ -608,7 +613,10 @@ def search(self, query_embedding: List[float], k: int=10, filter: Optional[Union
608613
Returns:
609614
List: List of similar records.
610615
"""
611-
(query, params) = self.builder.search_query(np.array(query_embedding), k, filter)
616+
if query_embedding is not None:
617+
query_embedding = np.array(query_embedding)
618+
619+
(query, params) = self.builder.search_query(query_embedding, k, filter)
612620
query, params = self._translate_to_pyformat(query, params)
613621
with self.connect() as conn:
614622
with conn.cursor() as cur:

0 commit comments

Comments
 (0)