|
247 | 247 | "\n",
|
248 | 248 | " return (where, params) \n",
|
249 | 249 | "\n",
|
250 |
| - " def search_query(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
| 250 | + " def search_query(self, query_embedding: Optional[List[float]], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
251 | 251 | " \"\"\"\n",
|
252 | 252 | " Generates a similarity query.\n",
|
253 | 253 | "\n",
|
254 | 254 | " Args:\n",
|
255 |
| - " query_embedding (List[float]): The query embedding vector.\n", |
| 255 | + " query_embedding (Optional[List[float]], optional): The query embedding vector. When None, results are not ordered by distance.\n", |
256 | 256 | " k (int, optional): The number of nearest neighbors to retrieve. Default is 10.\n",
|
257 | 257 | " filter (Optional[Union[Dict[str, str], List[Dict[str, str]]]], optional): A metadata filter, given as a single dict or a list of dicts. Default is None.\n",
|
258 | 258 | "\n",
|
259 | 259 | " Returns:\n",
|
260 | 260 | " Tuple[str, List]: A tuple containing the query and parameters.\n",
|
261 | 261 | " \"\"\"\n",
|
262 | 262 | " params = []\n",
|
263 |
| - " distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n", |
264 |
| - " params = params + [query_embedding]\n", |
| 263 | + " if query_embedding is not None:\n", |
| 264 | + " distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n", |
| 265 | + " params = params + [query_embedding]\n", |
| 266 | + " order_by_clause = \"ORDER BY {distance} ASC\".format(distance=distance)\n", |
| 267 | + " else:\n", |
| 268 | + " distance = \"-1.0\"\n", |
| 269 | + " order_by_clause = \"\"\n", |
265 | 270 | "\n",
|
266 | 271 | " (where, params) = self._where_clause_for_filter(params, filter)\n",
|
267 | 272 | "\n",
|
|
272 | 277 | " {table_name}\n",
|
273 | 278 | " WHERE \n",
|
274 | 279 | " {where}\n",
|
275 |
| - " ORDER BY {distance} ASC\n", |
| 280 | + " {order_by_clause}\n", |
276 | 281 | " LIMIT {k}\n",
|
277 |
| - " '''.format(distance=distance, where=where, table_name=self._quote_ident(self.table_name), k=k)\n", |
| 282 | + " '''.format(distance=distance, order_by_clause=order_by_clause, where=where, table_name=self._quote_ident(self.table_name), k=k)\n", |
278 | 283 | " return (query, params)"
|
279 | 284 | ]
|
280 | 285 | },
|
|
504 | 509 | " await pool.execute(query)\n",
|
505 | 510 | "\n",
|
506 | 511 | " async def search(self, \n",
|
507 |
| - " query_embedding: List[float], # vector to search for\n", |
| 512 | + " query_embedding: Optional[List[float]] = None, # vector to search for\n", |
508 | 513 | " k: int=10, # The number of nearest neighbors to retrieve. Default is 10.\n",
|
509 | 514 | " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None): # A filter for metadata. Default is None.\n",
|
510 | 515 | " \"\"\"\n",
|
|
622 | 627 | "\n",
|
623 | 628 | "### Async.search\n",
|
624 | 629 | "\n",
|
625 |
| - "> Async.search (query_embedding:List[float], k:int=10,\n", |
| 630 | + "> Async.search (query_embedding:Optional[List[float]]=None, k:int=10,\n", |
626 | 631 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
|
627 | 632 | "> ne)\n",
|
628 | 633 | "\n",
|
|
638 | 643 | "\n",
|
639 | 644 | "### Async.search\n",
|
640 | 645 | "\n",
|
641 |
| - "> Async.search (query_embedding:List[float], k:int=10,\n", |
| 646 | + "> Async.search (query_embedding:Optional[List[float]]=None, k:int=10,\n", |
642 | 647 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
|
643 | 648 | "> ne)\n",
|
644 | 649 | "\n",
|
|
715 | 720 | "assert len(rec) == 10\n",
|
716 | 721 | "rec = await vec.search([1.0, 2.0], k=4)\n",
|
717 | 722 | "assert len(rec) == 4\n",
|
| 723 | + "rec = await vec.search(k=4)\n", |
| 724 | + "assert len(rec) == 4\n", |
718 | 725 | "rec = await vec.search([1.0, 2.0], k=4, filter={\"key2\":\"val2\"})\n",
|
719 | 726 | "assert len(rec) == 1\n",
|
720 | 727 | "rec = await vec.search([1.0, 2.0], k=4, filter={\"key2\":\"does not exist\"})\n",
|
|
725 | 732 | "assert len(rec) == 1\n",
|
726 | 733 | "rec = await vec.search([1.0, 2.0], k=4, filter={\"key_1\":\"val_1\", \"key_2\":\"val_3\"})\n",
|
727 | 734 | "assert len(rec) == 0\n",
|
728 |
| - "\n", |
| 735 | + "rec = await vec.search(k=4, filter={\"key_1\":\"val_1\", \"key_2\":\"val_3\"})\n", |
| 736 | + "assert len(rec) == 0\n", |
729 | 737 | "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
|
730 | 738 | "assert len(rec) == 2\n",
|
| 739 | + "rec = await vec.search(k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 740 | + "assert len(rec) == 2\n", |
731 | 741 | "\n",
|
732 | 742 | "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}, {\"no such key\": \"no such val\"}])\n",
|
733 | 743 | "assert len(rec) == 2\n",
|
|
1010 | 1020 | " with conn.cursor() as cur:\n",
|
1011 | 1021 | " cur.execute(query)\n",
|
1012 | 1022 | "\n",
|
1013 |
| - " def search(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):\n", |
| 1023 | + " def search(self, query_embedding: Optional[List[float]]=None, k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):\n", |
1014 | 1024 | " \"\"\"\n",
|
1015 | 1025 | " Retrieves similar records using a similarity query.\n",
|
1016 | 1026 | "\n",
|
|
1022 | 1032 | " Returns:\n",
|
1023 | 1033 | " List: List of similar records.\n",
|
1024 | 1034 | " \"\"\"\n",
|
1025 |
| - " (query, params) = self.builder.search_query(np.array(query_embedding), k, filter)\n", |
| 1035 | + " if query_embedding is not None:\n", |
| 1036 | + " query_embedding = np.array(query_embedding)\n", |
| 1037 | + " \n", |
| 1038 | + " (query, params) = self.builder.search_query(query_embedding, k, filter)\n", |
1026 | 1039 | " query, params = self._translate_to_pyformat(query, params)\n",
|
1027 | 1040 | " with self.connect() as conn:\n",
|
1028 | 1041 | " with conn.cursor() as cur:\n",
|
|
1140 | 1153 | "\n",
|
1141 | 1154 | "### Sync.search\n",
|
1142 | 1155 | "\n",
|
1143 |
| - "> Sync.search (query_embedding:List[float], k:int=10,\n", |
| 1156 | + "> Sync.search (query_embedding:Optional[List[float]]=None, k:int=10,\n", |
1144 | 1157 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
|
1145 | 1158 | "> e)\n",
|
1146 | 1159 | "\n",
|
|
1161 | 1174 | "\n",
|
1162 | 1175 | "### Sync.search\n",
|
1163 | 1176 | "\n",
|
1164 |
| - "> Sync.search (query_embedding:List[float], k:int=10,\n", |
| 1177 | + "> Sync.search (query_embedding:Optional[List[float]]=None, k:int=10,\n", |
1165 | 1178 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
|
1166 | 1179 | "> e)\n",
|
1167 | 1180 | "\n",
|
|
1246 | 1259 | "assert len(rec) == 10\n",
|
1247 | 1260 | "rec = vec.search([1.0, 2.0], k=4)\n",
|
1248 | 1261 | "assert len(rec) == 4\n",
|
| 1262 | + "rec = vec.search(k=4)\n", |
| 1263 | + "assert len(rec) == 4\n", |
1249 | 1264 | "rec = vec.search([1.0, 2.0], k=4, filter={\"key2\":\"val2\"})\n",
|
1250 | 1265 | "assert len(rec) == 1\n",
|
1251 | 1266 | "rec = vec.search([1.0, 2.0], k=4, filter={\"key2\":\"does not exist\"})\n",
|
1252 | 1267 | "assert len(rec) == 0\n",
|
| 1268 | + "rec = vec.search(k=4, filter={\"key2\":\"does not exist\"})\n", |
| 1269 | + "assert len(rec) == 0\n", |
1253 | 1270 | "rec = vec.search([1.0, 2.0], k=4, filter={\"key_1\":\"val_1\"})\n",
|
1254 | 1271 | "assert len(rec) == 1\n",
|
1255 | 1272 | "rec = vec.search([1.0, 2.0], filter={\"key_1\":\"val_1\", \"key_2\":\"val_2\"})\n",
|
|
0 commit comments