|
100 | 100 | " self,\n",
|
101 | 101 | " table_name: str,\n",
|
102 | 102 | " num_dimensions: int,\n",
|
103 |
| - " distance_type: str = 'cosine') -> None:\n", |
| 103 | + " distance_type: str,\n", |
| 104 | + " id_type: str) -> None:\n", |
104 | 105 | " \"\"\"\n",
|
105 | 106 | " Initializes a base Vector object to generate queries for vector clients.\n",
|
106 | 107 | "\n",
|
|
118 | 119 | " else:\n",
|
119 | 120 | " raise ValueError(f\"unrecognized distance_type {distance_type}\")\n",
|
120 | 121 | "\n",
|
| 122 | + " if id_type.lower() != 'uuid' and id_type.lower() != 'text':\n", |
| 123 | + " raise ValueError(f\"unrecognized id_type {id_type}\")\n", |
| 124 | + "\n", |
| 125 | + " self.id_type = id_type.lower()\n", |
| 126 | + "\n", |
121 | 127 | " def _quote_ident(self, ident):\n",
|
122 | 128 | " \"\"\"\n",
|
123 | 129 | " Quotes an identifier to prevent SQL injection.\n",
|
|
170 | 176 | "CREATE EXTENSION IF NOT EXISTS vector;\n",
|
171 | 177 | "\n",
|
172 | 178 | "CREATE TABLE IF NOT EXISTS {table_name} (\n",
|
173 |
| - " id UUID PRIMARY KEY,\n", |
| 179 | + " id {id_type} PRIMARY KEY,\n", |
174 | 180 | " metadata JSONB,\n",
|
175 | 181 | " contents TEXT,\n",
|
176 | 182 | " embedding VECTOR({dimensions})\n",
|
177 | 183 | ");\n",
|
178 | 184 | "\n",
|
179 | 185 | "CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} USING GIN(metadata jsonb_path_ops);\n",
|
180 |
| - "'''.format(table_name=self._quote_ident(self.table_name), index_name=self._quote_ident(self.table_name+\"_meta_idx\"), dimensions=self.num_dimensions)\n", |
| 186 | + "'''.format(table_name=self._quote_ident(self.table_name), id_type=self.id_type, index_name=self._quote_ident(self.table_name+\"_meta_idx\"), dimensions=self.num_dimensions)\n", |
181 | 187 | "\n",
|
182 | 188 | " def _get_embedding_index_name(self):\n",
|
183 | 189 | " return self._quote_ident(self.table_name+\"_embedding_idx\")\n",
|
|
188 | 194 | " def delete_all_query(self):\n",
|
189 | 195 | " return \"TRUNCATE {table_name};\".format(table_name=self._quote_ident(self.table_name))\n",
|
190 | 196 | "\n",
|
191 |
| - " def delete_by_ids_query(self, id: List[uuid.UUID]) -> Tuple[str, List]:\n", |
192 |
| - " query = \"DELETE FROM {table_name} WHERE id = ANY($1::uuid[]);\".format(\n", |
193 |
| - " table_name=self._quote_ident(self.table_name))\n", |
194 |
| - " return (query, [id])\n", |
| 197 | + " def delete_by_ids_query(self, ids: Union[List[uuid.UUID], List[str]]) -> Tuple[str, List]:\n", |
| 198 | + " query = \"DELETE FROM {table_name} WHERE id = ANY($1::{id_type}[]);\".format(\n", |
| 199 | + " table_name=self._quote_ident(self.table_name), id_type=self.id_type)\n", |
| 200 | + " return (query, [ids])\n", |
195 | 201 | "\n",
|
196 | 202 | " def delete_by_metadata_query(self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:\n",
|
197 | 203 | " params: List[Any] = []\n",
|
|
355 | 361 | " service_url: str,\n",
|
356 | 362 | " table_name: str,\n",
|
357 | 363 | " num_dimensions: int,\n",
|
358 |
| - " distance_type: str = 'cosine') -> None:\n", |
| 364 | + " distance_type: str = 'cosine',\n", |
| 365 | + " id_type='UUID') -> None:\n", |
359 | 366 | " \"\"\"\n",
|
360 | 367 | " Initializes an async client for storing vector data.\n",
|
361 | 368 | "\n",
|
|
365 | 372 | " num_dimensions (int): The number of dimensions for the embedding vector.\n",
|
366 | 373 | " distance_type (str, optional): The distance type for indexing. Default is 'cosine' or '<=>'.\n",
|
367 | 374 | " \"\"\"\n",
|
368 |
| - " self.builder = QueryBuilder(table_name, num_dimensions, distance_type)\n", |
| 375 | + " self.builder = QueryBuilder(\n", |
| 376 | + " table_name, num_dimensions, distance_type, id_type)\n", |
369 | 377 | " self.service_url = service_url\n",
|
370 | 378 | " self.pool = None\n",
|
371 | 379 | "\n",
|
|
452 | 460 | " async with await self.connect() as pool:\n",
|
453 | 461 | " await pool.execute(query)\n",
|
454 | 462 | "\n",
|
455 |
| - " async def delete_by_ids(self, id: List[uuid.UUID]):\n", |
| 463 | + " async def delete_by_ids(self, ids: Union[List[uuid.UUID], List[str]]):\n", |
456 | 464 | " \"\"\"\n",
|
457 | 465 | " Delete records by id.\n",
|
458 | 466 | " \"\"\"\n",
|
459 |
| - " (query, params) = self.builder.delete_by_ids_query(id)\n", |
| 467 | + " (query, params) = self.builder.delete_by_ids_query(ids)\n", |
460 | 468 | " async with await self.connect() as pool:\n",
|
461 | 469 | " return await pool.fetch(query, *params)\n",
|
462 | 470 | "\n",
|
|
797 | 805 | "assert await vec.table_is_empty()\n",
|
798 | 806 | "\n",
|
799 | 807 | "await vec.drop_table()\n",
|
| 808 | + "await vec.close()\n", |
| 809 | + "\n", |
| 810 | + "vec = Async(service_url, \"data_table\", 2, id_type=\"TEXT\")\n", |
| 811 | + "await vec.create_tables()\n", |
| 812 | + "empty = await vec.table_is_empty()\n", |
| 813 | + "assert empty\n", |
| 814 | + "await vec.upsert([(\"Not a valid UUID\", {\"key\": \"val\"}, \"the brown fox\", [1.0, 1.2])])\n", |
| 815 | + "empty = await vec.table_is_empty()\n", |
| 816 | + "assert not empty\n", |
| 817 | + "await vec.delete_by_ids([\"Not a valid UUID\"])\n", |
| 818 | + "empty = await vec.table_is_empty()\n", |
| 819 | + "assert empty\n", |
| 820 | + "await vec.drop_table()\n", |
800 | 821 | "await vec.close()"
|
801 | 822 | ]
|
802 | 823 | },
|
|
838 | 859 | " service_url: str,\n",
|
839 | 860 | " table_name: str,\n",
|
840 | 861 | " num_dimensions: int,\n",
|
841 |
| - " distance_type: str = 'cosine') -> None:\n", |
842 |
| - " self.builder = QueryBuilder(table_name, num_dimensions, distance_type)\n", |
| 862 | + " distance_type: str = 'cosine',\n", |
| 863 | + " id_type='UUID') -> None:\n", |
| 864 | + " self.builder = QueryBuilder(\n", |
| 865 | + " table_name, num_dimensions, distance_type, id_type)\n", |
843 | 866 | " self.service_url = service_url\n",
|
844 | 867 | " self.pool = None\n",
|
845 | 868 | " psycopg2.extras.register_uuid()\n",
|
|
968 | 991 | " with conn.cursor() as cur:\n",
|
969 | 992 | " cur.execute(query)\n",
|
970 | 993 | "\n",
|
971 |
| - " def delete_by_ids(self, id: List[uuid.UUID]):\n", |
| 994 | + " def delete_by_ids(self, ids: Union[List[uuid.UUID], List[str]]):\n", |
972 | 995 | " \"\"\"\n",
|
973 | 996 | " Delete records by id.\n",
|
974 | 997 | " \"\"\"\n",
|
975 |
| - " (query, params) = self.builder.delete_by_ids_query(id)\n", |
| 998 | + " (query, params) = self.builder.delete_by_ids_query(ids)\n", |
976 | 999 | " query, params = self._translate_to_pyformat(query, params)\n",
|
977 | 1000 | " with self.connect() as conn:\n",
|
978 | 1001 | " with conn.cursor() as cur:\n",
|
|
1359 | 1382 | "assert vec.table_is_empty()\n",
|
1360 | 1383 | "\n",
|
1361 | 1384 | "vec.drop_table()\n",
|
| 1385 | + "vec.close()\n", |
1362 | 1386 | "\n",
|
| 1387 | + "vec = Sync(service_url, \"data_table\", 2, id_type=\"TEXT\")\n", |
| 1388 | + "vec.create_tables()\n", |
| 1389 | + "assert vec.table_is_empty()\n", |
| 1390 | + "vec.upsert([(\"Not a valid UUID\", {\"key\": \"val\"}, \"the brown fox\", [1.0, 1.2])])\n", |
| 1391 | + "assert not vec.table_is_empty()\n", |
| 1392 | + "vec.delete_by_ids([\"Not a valid UUID\"])\n", |
| 1393 | + "assert vec.table_is_empty()\n", |
| 1394 | + "vec.drop_table()\n", |
1363 | 1395 | "vec.close()"
|
1364 | 1396 | ]
|
1365 | 1397 | },
|
|
0 commit comments