Skip to content

Commit 9bacf7f

Browse files
committed
Add delete_by_id and delete_by_metadata
1 parent 138bbff commit 9bacf7f

File tree

5 files changed

+284
-48
lines changed

5 files changed

+284
-48
lines changed

README.md

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,24 @@ Now you can query for similar items:
7878
await vec.search([1.0, 9.0])
7979
```
8080

81-
[<Record id=UUID('0313cfac-07e4-4c01-9651-0917ecb1991c') metadata='{"action": "jump", "animal": "fox"}' contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>,
82-
<Record id=UUID('3bde3dd3-9445-4d9e-b72e-f329d19c380d') metadata='{"animal": "fox"}' contents='the brown fox' embedding=array([1. , 1.3], dtype=float32) distance=0.14489260377438218>]
81+
[<Record id=UUID('1bd6a985-a837-4742-a007-d8a785e7089f') metadata={'action': 'jump', 'animal': 'fox'} contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>,
82+
<Record id=UUID('2e52b4a4-3422-42d7-8e62-fd40731e7ffa') metadata={'animal': 'fox'} contents='the brown fox' embedding=array([1. , 1.3], dtype=float32) distance=0.14489260377438218>]
8383

8484
You can specify the number of records to return.
8585

8686
``` python
8787
await vec.search([1.0, 9.0], k=1)
8888
```
8989

90-
[<Record id=UUID('0313cfac-07e4-4c01-9651-0917ecb1991c') metadata='{"action": "jump", "animal": "fox"}' contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>]
90+
[<Record id=UUID('1bd6a985-a837-4742-a007-d8a785e7089f') metadata={'action': 'jump', 'animal': 'fox'} contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>]
9191

9292
You can also specify a filter on the metadata as a simple dictionary:
9393

9494
``` python
9595
await vec.search([1.0, 9.0], k=1, filter={"action": "jump"})
9696
```
9797

98-
[<Record id=UUID('0313cfac-07e4-4c01-9651-0917ecb1991c') metadata='{"action": "jump", "animal": "fox"}' contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>]
98+
[<Record id=UUID('1bd6a985-a837-4742-a007-d8a785e7089f') metadata={'action': 'jump', 'animal': 'fox'} contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>]
9999

100100
You can also specify a list of filter dictionaries, where an item is
101101
returned if it matches any of the dictionaries:
@@ -104,8 +104,8 @@ returned if it matches any dict
104104
await vec.search([1.0, 9.0], k=2, filter=[{"action": "jump"}, {"animal": "fox"}])
105105
```
106106

107-
[<Record id=UUID('0313cfac-07e4-4c01-9651-0917ecb1991c') metadata='{"action": "jump", "animal": "fox"}' contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>,
108-
<Record id=UUID('3bde3dd3-9445-4d9e-b72e-f329d19c380d') metadata='{"animal": "fox"}' contents='the brown fox' embedding=array([1. , 1.3], dtype=float32) distance=0.14489260377438218>]
107+
[<Record id=UUID('1bd6a985-a837-4742-a007-d8a785e7089f') metadata={'action': 'jump', 'animal': 'fox'} contents='jumped over the' embedding=array([ 1. , 10.8], dtype=float32) distance=0.00016793422934946456>,
108+
<Record id=UUID('2e52b4a4-3422-42d7-8e62-fd40731e7ffa') metadata={'animal': 'fox'} contents='the brown fox' embedding=array([1. , 1.3], dtype=float32) distance=0.14489260377438218>]
109109

110110
You can access the fields as follows:
111111

@@ -114,13 +114,13 @@ records = await vec.search([1.0, 9.0], k=1, filter={"action": "jump"})
114114
records[0][client.SEARCH_RESULT_ID_IDX]
115115
```
116116

117-
UUID('d282ad19-1a69-4a9d-8a15-6f06262e109a')
117+
UUID('1bd6a985-a837-4742-a007-d8a785e7089f')
118118

119119
``` python
120120
records[0][client.SEARCH_RESULT_METADATA_IDX]
121121
```
122122

123-
'{"action": "jump", "animal": "fox"}'
123+
{'action': 'jump', 'animal': 'fox'}
124124

125125
``` python
126126
records[0][client.SEARCH_RESULT_CONTENTS_IDX]
@@ -140,6 +140,22 @@ records[0][client.SEARCH_RESULT_DISTANCE_IDX]
140140

141141
0.00016793422934946456
142142

143+
You can delete by ID:
144+
145+
``` python
146+
await vec.delete_by_id(records[0][client.SEARCH_RESULT_ID_IDX])
147+
```
148+
149+
[]
150+
151+
Or you can delete by metadata filters:
152+
153+
``` python
154+
await vec.delete_by_metadata({"action": "jump"})
155+
```
156+
157+
[]
158+
143159
To delete all records, use:
144160

145161
``` python

nbs/00_vector.ipynb

Lines changed: 95 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@
138138
" \"\"\"\n",
139139
" return \"SELECT 1 FROM {table_name} LIMIT 1\".format(table_name=self._quote_ident(self.table_name))\n",
140140
"\n",
141-
" #| export\n",
142141
" def get_upsert_query(self):\n",
143142
" \"\"\"\n",
144143
" Generates an upsert query.\n",
@@ -188,6 +187,16 @@
188187
" def delete_all_query(self):\n",
189188
" return \"TRUNCATE {table_name};\".format(table_name=self._quote_ident(self.table_name))\n",
190189
"\n",
190+
" def delete_by_id_query(self, id: uuid.UUID) -> Tuple[str, List]:\n",
191+
" query = \"DELETE FROM {table_name} WHERE id = $1;\".format(table_name=self._quote_ident(self.table_name))\n",
192+
" return (query, [id])\n",
193+
"\n",
194+
" def delete_by_metadata_query (self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:\n",
195+
" params = []\n",
196+
" (where, params) = self._where_clause_for_filter(params, filter)\n",
197+
" query = \"DELETE FROM {table_name} WHERE {where};\".format(table_name=self._quote_ident(self.table_name), where=where)\n",
198+
" return (query, params) \n",
199+
"\n",
191200
" def drop_table_query(self):\n",
192201
" return \"DROP TABLE IF EXISTS {table_name};\".format(table_name=self._quote_ident(self.table_name))\n",
193202
" \n",
@@ -222,6 +231,22 @@
222231
" return \"CREATE INDEX {index_name} ON {table_name} USING ivfflat ({column_name} {index_method}) WITH (lists = {num_lists});\"\\\n",
223232
" .format(index_name=self._get_embedding_index_name(), table_name=self._quote_ident(self.table_name), column_name=self._quote_ident(column_name), index_method=index_method, num_lists=num_lists)\n",
224233
"\n",
234+
" def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]]) -> Tuple[str, List]:\n",
235+
" if isinstance(filter, dict):\n",
236+
" where = \"metadata @> ${index}\".format(index=len(params)+1)\n",
237+
" json_object = json.dumps(filter)\n",
238+
" params = params + [json_object]\n",
239+
" elif isinstance(filter, list):\n",
240+
" any_params = []\n",
241+
" for idx, filter_dict in enumerate(filter, start=len(params) + 1):\n",
242+
" any_params.append(json.dumps(filter_dict))\n",
243+
" where = \"metadata @> ANY(${index}::jsonb[])\".format(index=len(params) + 1)\n",
244+
" params = params + [any_params]\n",
245+
" else:\n",
246+
" where = \"TRUE\"\n",
247+
"\n",
248+
" return (where, params) \n",
249+
"\n",
225250
" def search_query(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n",
226251
" \"\"\"\n",
227252
" Generates a similarity query.\n",
@@ -238,19 +263,8 @@
238263
" distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n",
239264
" params = params + [query_embedding]\n",
240265
"\n",
241-
" if isinstance(filter, dict):\n",
242-
" where = \"metadata @> ${index}\".format(index=len(params)+1)\n",
243-
" json_object = json.dumps(filter)\n",
244-
" params = params + [json_object]\n",
245-
" elif isinstance(filter, list):\n",
246-
" any_params = []\n",
247-
" for idx, filter_dict in enumerate(filter, start=len(params) + 1):\n",
248-
" any_params.append(json.dumps(filter_dict))\n",
249-
" where = \"metadata @> ANY(${index}::jsonb[])\".format(index=len(params) + 1)\n",
250-
" params = params + [any_params]\n",
251-
" else:\n",
252-
" where = \"TRUE\"\n",
253-
" \n",
266+
" (where, params) = self._where_clause_for_filter(params, filter)\n",
267+
"\n",
254268
" query = '''\n",
255269
" SELECT\n",
256270
" id, metadata, contents, embedding, {distance} as distance\n",
@@ -421,7 +435,24 @@
421435
" query = self.builder.delete_all_query()\n",
422436
" async with await self.connect() as pool:\n",
423437
" await pool.execute(query)\n",
424-
" \n",
438+
"\n",
439+
" async def delete_by_id(self, id: uuid.UUID):\n",
440+
" \"\"\"\n",
441+
" Delete records by id.\n",
442+
" \"\"\"\n",
443+
" (query, params) = self.builder.delete_by_id_query(id)\n",
444+
" async with await self.connect() as pool:\n",
445+
" return await pool.fetch(query, *params)\n",
446+
"\n",
447+
" async def delete_by_metadata(self, filter: Union[Dict[str, str], List[Dict[str, str]]]):\n",
448+
" \"\"\"\n",
449+
" Delete records by metadata filters.\n",
450+
" \"\"\"\n",
451+
" (query, params) = self.builder.delete_by_metadata_query(filter)\n",
452+
" async with await self.connect() as pool:\n",
453+
" return await pool.fetch(query, *params)\n",
454+
"\n",
455+
"\n",
425456
" async def drop_table(self):\n",
426457
" \"\"\"\n",
427458
" Drops the table\n",
@@ -723,7 +754,21 @@
723754
"except BaseException as e:\n",
724755
" pass\n",
725756
"\n",
757+
"rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
758+
"assert len(rec) == 2\n",
759+
"await vec.delete_by_id(rec[0][SEARCH_RESULT_ID_IDX])\n",
760+
"rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
761+
"assert len(rec) == 1\n",
762+
"await vec.delete_by_metadata([{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
763+
"rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
764+
"assert len(rec) == 0\n",
765+
"rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n",
766+
"assert len(rec) == 4\n",
767+
"await vec.delete_by_metadata([{\"key2\":\"val\"}])\n",
768+
"rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n",
769+
"assert len(rec) == 0\n",
726770
"\n",
771+
"assert not await vec.table_is_empty()\n",
727772
"await vec.delete_all()\n",
728773
"assert await vec.table_is_empty()\n",
729774
"\n",
@@ -889,6 +934,26 @@
889934
" with self.connect() as conn:\n",
890935
" with conn.cursor() as cur:\n",
891936
" cur.execute(query)\n",
937+
" \n",
938+
" def delete_by_id(self, id: uuid.UUID):\n",
939+
" \"\"\"\n",
940+
" Delete records by id.\n",
941+
" \"\"\"\n",
942+
" (query, params) = self.builder.delete_by_id_query(id)\n",
943+
" query, params = self._translate_to_pyformat(query, params)\n",
944+
" with self.connect() as conn:\n",
945+
" with conn.cursor() as cur:\n",
946+
" cur.execute(query, params)\n",
947+
"\n",
948+
" def delete_by_metadata(self, filter: Union[Dict[str, str], List[Dict[str, str]]]):\n",
949+
" \"\"\"\n",
950+
" Delete records by metadata filters.\n",
951+
" \"\"\"\n",
952+
" (query, params) = self.builder.delete_by_metadata_query(filter)\n",
953+
" query, params = self._translate_to_pyformat(query, params)\n",
954+
" with self.connect() as conn:\n",
955+
" with conn.cursor() as cur:\n",
956+
" cur.execute(query, params)\n",
892957
"\n",
893958
" def drop_table(self):\n",
894959
" \"\"\"\n",
@@ -1224,6 +1289,21 @@
12241289
"assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)\n",
12251290
"assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556\n",
12261291
"\n",
1292+
"rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
1293+
"len(rec) == 2\n",
1294+
"vec.delete_by_id(rec[0][SEARCH_RESULT_ID_IDX])\n",
1295+
"rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
1296+
"assert len(rec) == 1\n",
1297+
"vec.delete_by_metadata([{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
1298+
"rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n",
1299+
"assert len(rec) == 0\n",
1300+
"rec = vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n",
1301+
"assert len(rec) == 4\n",
1302+
"vec.delete_by_metadata([{\"key2\":\"val\"}])\n",
1303+
"rec = vec.search([1.0, 2.0], k=4, filter=[{\"key2\":\"val\"}])\n",
1304+
"len(rec) == 0\n",
1305+
"\n",
1306+
"assert not vec.table_is_empty()\n",
12271307
"vec.delete_all()\n",
12281308
"assert vec.table_is_empty()\n",
12291309
"\n",

0 commit comments

Comments
 (0)