|
138 | 138 | " \"\"\"\n",
|
139 | 139 | " return \"SELECT 1 FROM {table_name} LIMIT 1\".format(table_name=self._quote_ident(self.table_name))\n",
|
140 | 140 | "\n",
|
141 |
| - " #| export\n", |
142 | 141 | " def get_upsert_query(self):\n",
|
143 | 142 | " \"\"\"\n",
|
144 | 143 | " Generates an upsert query.\n",
|
|
188 | 187 | " def delete_all_query(self):\n",
|
189 | 188 | " return \"TRUNCATE {table_name};\".format(table_name=self._quote_ident(self.table_name))\n",
|
190 | 189 | "\n",
|
| 190 | + " def delete_by_id_query(self, id: uuid.UUID) -> Tuple[str, List]:\n", |
| 191 | + " query = \"DELETE FROM {table_name} WHERE id = $1;\".format(table_name=self._quote_ident(self.table_name))\n", |
| 192 | + " return (query, [id])\n", |
| 193 | + "\n", |
| 194 | + " def delete_by_metadata_query (self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:\n", |
| 195 | + " params = []\n", |
| 196 | + " (where, params) = self._where_clause_for_filter(params, filter)\n", |
| 197 | + " query = \"DELETE FROM {table_name} WHERE {where};\".format(table_name=self._quote_ident(self.table_name), where=where)\n", |
| 198 | + " return (query, params) \n", |
| 199 | + "\n", |
191 | 200 | " def drop_table_query(self):\n",
|
192 | 201 | " return \"DROP TABLE IF EXISTS {table_name};\".format(table_name=self._quote_ident(self.table_name))\n",
|
193 | 202 | " \n",
|
|
222 | 231 | " return \"CREATE INDEX {index_name} ON {table_name} USING ivfflat ({column_name} {index_method}) WITH (lists = {num_lists});\"\\\n",
|
223 | 232 | " .format(index_name=self._get_embedding_index_name(), table_name=self._quote_ident(self.table_name), column_name=self._quote_ident(column_name), index_method=index_method, num_lists=num_lists)\n",
|
224 | 233 | "\n",
|
| 234 | + " def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]]) -> Tuple[str, List]:\n", |
| 235 | + " if isinstance(filter, dict):\n", |
| 236 | + " where = \"metadata @> ${index}\".format(index=len(params)+1)\n", |
| 237 | + " json_object = json.dumps(filter)\n", |
| 238 | + " params = params + [json_object]\n", |
| 239 | + " elif isinstance(filter, list):\n", |
| 240 | + " any_params = []\n", |
| 241 | + " for idx, filter_dict in enumerate(filter, start=len(params) + 1):\n", |
| 242 | + " any_params.append(json.dumps(filter_dict))\n", |
| 243 | + " where = \"metadata @> ANY(${index}::jsonb[])\".format(index=len(params) + 1)\n", |
| 244 | + " params = params + [any_params]\n", |
| 245 | + " else:\n", |
| 246 | + " where = \"TRUE\"\n", |
| 247 | + "\n", |
| 248 | + " return (where, params) \n", |
| 249 | + "\n", |
225 | 250 | " def search_query(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n",
|
226 | 251 | " \"\"\"\n",
|
227 | 252 | " Generates a similarity query.\n",
|
|
238 | 263 | " distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n",
|
239 | 264 | " params = params + [query_embedding]\n",
|
240 | 265 | "\n",
|
241 |
| - " if isinstance(filter, dict):\n", |
242 |
| - " where = \"metadata @> ${index}\".format(index=len(params)+1)\n", |
243 |
| - " json_object = json.dumps(filter)\n", |
244 |
| - " params = params + [json_object]\n", |
245 |
| - " elif isinstance(filter, list):\n", |
246 |
| - " any_params = []\n", |
247 |
| - " for idx, filter_dict in enumerate(filter, start=len(params) + 1):\n", |
248 |
| - " any_params.append(json.dumps(filter_dict))\n", |
249 |
| - " where = \"metadata @> ANY(${index}::jsonb[])\".format(index=len(params) + 1)\n", |
250 |
| - " params = params + [any_params]\n", |
251 |
| - " else:\n", |
252 |
| - " where = \"TRUE\"\n", |
253 |
| - " \n", |
| 266 | + " (where, params) = self._where_clause_for_filter(params, filter)\n", |
| 267 | + "\n", |
254 | 268 | " query = '''\n",
|
255 | 269 | " SELECT\n",
|
256 | 270 | " id, metadata, contents, embedding, {distance} as distance\n",
|
|
421 | 435 | " query = self.builder.delete_all_query()\n",
|
422 | 436 | " async with await self.connect() as pool:\n",
|
423 | 437 | " await pool.execute(query)\n",
|
424 |
| - " \n", |
| 438 | + "\n", |
| 439 | + " async def delete_by_id(self, id: uuid.UUID):\n", |
| 440 | + " \"\"\"\n", |
| 441 | + " Delete records by id.\n", |
| 442 | + " \"\"\"\n", |
| 443 | + " (query, params) = self.builder.delete_by_id_query(id)\n", |
| 444 | + " async with await self.connect() as pool:\n", |
| 445 | + " return await pool.fetch(query, *params)\n", |
| 446 | + "\n", |
| 447 | + " async def delete_by_metadata(self, filter: Union[Dict[str, str], List[Dict[str, str]]]):\n", |
| 448 | + " \"\"\"\n", |
| 449 | + " Delete records by metadata filters.\n", |
| 450 | + " \"\"\"\n", |
| 451 | + " (query, params) = self.builder.delete_by_metadata_query(filter)\n", |
| 452 | + " async with await self.connect() as pool:\n", |
| 453 | + " return await pool.fetch(query, *params)\n", |
| 454 | + "\n", |
| 455 | + "\n", |
425 | 456 | " async def drop_table(self):\n",
|
426 | 457 | " \"\"\"\n",
|
427 | 458 | " Drops the table\n",
|
|
723 | 754 | "except BaseException as e:\n",
|
724 | 755 | " pass\n",
|
725 | 756 | "\n",
|
| 757 | + "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 758 | + "assert len(rec) == 2\n", |
| 759 | + "await vec.delete_by_id(rec[0][SEARCH_RESULT_ID_IDX])\n", |
| 760 | + "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 761 | + "assert len(rec) == 1\n", |
| 762 | + "await vec.delete_by_metadata([{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 763 | + "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 764 | + "assert len(rec) == 0\n", |
| 765 | + "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n", |
| 766 | + "assert len(rec) == 4\n", |
| 767 | + "await vec.delete_by_metadata([{\"key2\":\"val\"}])\n", |
| 768 | + "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n", |
| 769 | + "assert len(rec) == 0\n", |
726 | 770 | "\n",
|
| 771 | + "assert not await vec.table_is_empty()\n", |
727 | 772 | "await vec.delete_all()\n",
|
728 | 773 | "assert await vec.table_is_empty()\n",
|
729 | 774 | "\n",
|
|
889 | 934 | " with self.connect() as conn:\n",
|
890 | 935 | " with conn.cursor() as cur:\n",
|
891 | 936 | " cur.execute(query)\n",
|
| 937 | + " \n", |
| 938 | + " def delete_by_id(self, id: uuid.UUID):\n", |
| 939 | + " \"\"\"\n", |
| 940 | + " Delete records by id.\n", |
| 941 | + " \"\"\"\n", |
| 942 | + " (query, params) = self.builder.delete_by_id_query(id)\n", |
| 943 | + " query, params = self._translate_to_pyformat(query, params)\n", |
| 944 | + " with self.connect() as conn:\n", |
| 945 | + " with conn.cursor() as cur:\n", |
| 946 | + " cur.execute(query, params)\n", |
| 947 | + "\n", |
| 948 | + " def delete_by_metadata(self, filter: Union[Dict[str, str], List[Dict[str, str]]]):\n", |
| 949 | + " \"\"\"\n", |
| 950 | + " Delete records by metadata filters.\n", |
| 951 | + " \"\"\"\n", |
| 952 | + " (query, params) = self.builder.delete_by_metadata_query(filter)\n", |
| 953 | + " query, params = self._translate_to_pyformat(query, params)\n", |
| 954 | + " with self.connect() as conn:\n", |
| 955 | + " with conn.cursor() as cur:\n", |
| 956 | + " cur.execute(query, params)\n", |
892 | 957 | "\n",
|
893 | 958 | " def drop_table(self):\n",
|
894 | 959 | " \"\"\"\n",
|
|
1224 | 1289 | "assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)\n",
|
1225 | 1290 | "assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556\n",
|
1226 | 1291 | "\n",
|
| 1292 | + "rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 1293 | + "len(rec) == 2\n", |
| 1294 | + "vec.delete_by_id(rec[0][SEARCH_RESULT_ID_IDX])\n", |
| 1295 | + "rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 1296 | + "assert len(rec) == 1\n", |
| 1297 | + "vec.delete_by_metadata([{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 1298 | + "rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 1299 | + "assert len(rec) == 0\n", |
| 1300 | + "rec = vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n", |
| 1301 | + "assert len(rec) == 4\n", |
| 1302 | + "vec.delete_by_metadata([{\"key2\":\"val\"}])\n", |
| 1303 | + "rec = vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n", |
| 1304 | + "len(rec) == 0\n", |
| 1305 | + "\n", |
| 1306 | + "assert not vec.table_is_empty()\n", |
1227 | 1307 | "vec.delete_all()\n",
|
1228 | 1308 | "assert vec.table_is_empty()\n",
|
1229 | 1309 | "\n",
|
|
0 commit comments