Skip to content

Commit d839964

Browse files
committed
Allow IDs to be strings
1 parent 4e46620 commit d839964

File tree

2 files changed

+72
-30
lines changed

2 files changed

+72
-30
lines changed

nbs/00_vector.ipynb

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@
100100
" self,\n",
101101
" table_name: str,\n",
102102
" num_dimensions: int,\n",
103-
" distance_type: str = 'cosine') -> None:\n",
103+
" distance_type: str,\n",
104+
" id_type: str) -> None:\n",
104105
" \"\"\"\n",
105106
" Initializes a base Vector object to generate queries for vector clients.\n",
106107
"\n",
@@ -118,6 +119,11 @@
118119
" else:\n",
119120
" raise ValueError(f\"unrecognized distance_type {distance_type}\")\n",
120121
"\n",
122+
" if id_type.lower() != 'uuid' and id_type.lower() != 'text':\n",
123+
" raise ValueError(f\"unrecognized id_type {id_type}\")\n",
124+
"\n",
125+
" self.id_type = id_type.lower()\n",
126+
"\n",
121127
" def _quote_ident(self, ident):\n",
122128
" \"\"\"\n",
123129
" Quotes an identifier to prevent SQL injection.\n",
@@ -170,14 +176,14 @@
170176
"CREATE EXTENSION IF NOT EXISTS vector;\n",
171177
"\n",
172178
"CREATE TABLE IF NOT EXISTS {table_name} (\n",
173-
" id UUID PRIMARY KEY,\n",
179+
" id {id_type} PRIMARY KEY,\n",
174180
" metadata JSONB,\n",
175181
" contents TEXT,\n",
176182
" embedding VECTOR({dimensions})\n",
177183
");\n",
178184
"\n",
179185
"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} USING GIN(metadata jsonb_path_ops);\n",
180-
"'''.format(table_name=self._quote_ident(self.table_name), index_name=self._quote_ident(self.table_name+\"_meta_idx\"), dimensions=self.num_dimensions)\n",
186+
"'''.format(table_name=self._quote_ident(self.table_name), id_type=self.id_type, index_name=self._quote_ident(self.table_name+\"_meta_idx\"), dimensions=self.num_dimensions)\n",
181187
"\n",
182188
" def _get_embedding_index_name(self):\n",
183189
" return self._quote_ident(self.table_name+\"_embedding_idx\")\n",
@@ -188,10 +194,10 @@
188194
" def delete_all_query(self):\n",
189195
" return \"TRUNCATE {table_name};\".format(table_name=self._quote_ident(self.table_name))\n",
190196
"\n",
191-
" def delete_by_ids_query(self, id: List[uuid.UUID]) -> Tuple[str, List]:\n",
192-
" query = \"DELETE FROM {table_name} WHERE id = ANY($1::uuid[]);\".format(\n",
193-
" table_name=self._quote_ident(self.table_name))\n",
194-
" return (query, [id])\n",
197+
" def delete_by_ids_query(self, ids: Union[List[uuid.UUID], List[str]]) -> Tuple[str, List]:\n",
198+
" query = \"DELETE FROM {table_name} WHERE id = ANY($1::{id_type}[]);\".format(\n",
199+
" table_name=self._quote_ident(self.table_name), id_type=self.id_type)\n",
200+
" return (query, [ids])\n",
195201
"\n",
196202
" def delete_by_metadata_query(self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:\n",
197203
" params: List[Any] = []\n",
@@ -355,7 +361,8 @@
355361
" service_url: str,\n",
356362
" table_name: str,\n",
357363
" num_dimensions: int,\n",
358-
" distance_type: str = 'cosine') -> None:\n",
364+
" distance_type: str = 'cosine',\n",
365+
" id_type='UUID') -> None:\n",
359366
" \"\"\"\n",
360367
" Initializes a async client for storing vector data.\n",
361368
"\n",
@@ -365,7 +372,8 @@
365372
" num_dimensions (int): The number of dimensions for the embedding vector.\n",
366373
" distance_type (str, optional): The distance type for indexing. Default is 'cosine' or '<=>'.\n",
367374
" \"\"\"\n",
368-
" self.builder = QueryBuilder(table_name, num_dimensions, distance_type)\n",
375+
" self.builder = QueryBuilder(\n",
376+
" table_name, num_dimensions, distance_type, id_type)\n",
369377
" self.service_url = service_url\n",
370378
" self.pool = None\n",
371379
"\n",
@@ -452,11 +460,11 @@
452460
" async with await self.connect() as pool:\n",
453461
" await pool.execute(query)\n",
454462
"\n",
455-
" async def delete_by_ids(self, id: List[uuid.UUID]):\n",
463+
" async def delete_by_ids(self, ids: Union[List[uuid.UUID], List[str]]):\n",
456464
" \"\"\"\n",
457465
" Delete records by id.\n",
458466
" \"\"\"\n",
459-
" (query, params) = self.builder.delete_by_ids_query(id)\n",
467+
" (query, params) = self.builder.delete_by_ids_query(ids)\n",
460468
" async with await self.connect() as pool:\n",
461469
" return await pool.fetch(query, *params)\n",
462470
"\n",
@@ -797,6 +805,19 @@
797805
"assert await vec.table_is_empty()\n",
798806
"\n",
799807
"await vec.drop_table()\n",
808+
"await vec.close()\n",
809+
"\n",
810+
"vec = Async(service_url, \"data_table\", 2, id_type=\"TEXT\")\n",
811+
"await vec.create_tables()\n",
812+
"empty = await vec.table_is_empty()\n",
813+
"assert empty\n",
814+
"await vec.upsert([(\"Not a valid UUID\", {\"key\": \"val\"}, \"the brown fox\", [1.0, 1.2])])\n",
815+
"empty = await vec.table_is_empty()\n",
816+
"assert not empty\n",
817+
"await vec.delete_by_ids([\"Not a valid UUID\"])\n",
818+
"empty = await vec.table_is_empty()\n",
819+
"assert empty\n",
820+
"await vec.drop_table()\n",
800821
"await vec.close()"
801822
]
802823
},
@@ -838,8 +859,10 @@
838859
" service_url: str,\n",
839860
" table_name: str,\n",
840861
" num_dimensions: int,\n",
841-
" distance_type: str = 'cosine') -> None:\n",
842-
" self.builder = QueryBuilder(table_name, num_dimensions, distance_type)\n",
862+
" distance_type: str = 'cosine',\n",
863+
" id_type='UUID') -> None:\n",
864+
" self.builder = QueryBuilder(\n",
865+
" table_name, num_dimensions, distance_type, id_type)\n",
843866
" self.service_url = service_url\n",
844867
" self.pool = None\n",
845868
" psycopg2.extras.register_uuid()\n",
@@ -968,11 +991,11 @@
968991
" with conn.cursor() as cur:\n",
969992
" cur.execute(query)\n",
970993
"\n",
971-
" def delete_by_ids(self, id: List[uuid.UUID]):\n",
994+
" def delete_by_ids(self, ids: Union[List[uuid.UUID], List[str]]):\n",
972995
" \"\"\"\n",
973996
" Delete records by id.\n",
974997
" \"\"\"\n",
975-
" (query, params) = self.builder.delete_by_ids_query(id)\n",
998+
" (query, params) = self.builder.delete_by_ids_query(ids)\n",
976999
" query, params = self._translate_to_pyformat(query, params)\n",
9771000
" with self.connect() as conn:\n",
9781001
" with conn.cursor() as cur:\n",
@@ -1359,7 +1382,16 @@
13591382
"assert vec.table_is_empty()\n",
13601383
"\n",
13611384
"vec.drop_table()\n",
1385+
"vec.close()\n",
13621386
"\n",
1387+
"vec = Sync(service_url, \"data_table\", 2, id_type=\"TEXT\")\n",
1388+
"vec.create_tables()\n",
1389+
"assert vec.table_is_empty()\n",
1390+
"vec.upsert([(\"Not a valid UUID\", {\"key\": \"val\"}, \"the brown fox\", [1.0, 1.2])])\n",
1391+
"assert not vec.table_is_empty()\n",
1392+
"vec.delete_by_ids([\"Not a valid UUID\"])\n",
1393+
"assert vec.table_is_empty()\n",
1394+
"vec.drop_table()\n",
13631395
"vec.close()"
13641396
]
13651397
},

timescale_vector/client.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def __init__(
2525
self,
2626
table_name: str,
2727
num_dimensions: int,
28-
distance_type: str = 'cosine') -> None:
28+
distance_type: str,
29+
id_type: str) -> None:
2930
"""
3031
Initializes a base Vector object to generate queries for vector clients.
3132
@@ -43,6 +44,11 @@ def __init__(
4344
else:
4445
raise ValueError(f"unrecognized distance_type {distance_type}")
4546

47+
if id_type.lower() != 'uuid' and id_type.lower() != 'text':
48+
raise ValueError(f"unrecognized id_type {id_type}")
49+
50+
self.id_type = id_type.lower()
51+
4652
def _quote_ident(self, ident):
4753
"""
4854
Quotes an identifier to prevent SQL injection.
@@ -95,14 +101,14 @@ def get_create_query(self):
95101
CREATE EXTENSION IF NOT EXISTS vector;
96102
97103
CREATE TABLE IF NOT EXISTS {table_name} (
98-
id UUID PRIMARY KEY,
104+
id {id_type} PRIMARY KEY,
99105
metadata JSONB,
100106
contents TEXT,
101107
embedding VECTOR({dimensions})
102108
);
103109
104110
CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} USING GIN(metadata jsonb_path_ops);
105-
'''.format(table_name=self._quote_ident(self.table_name), index_name=self._quote_ident(self.table_name+"_meta_idx"), dimensions=self.num_dimensions)
111+
'''.format(table_name=self._quote_ident(self.table_name), id_type=self.id_type, index_name=self._quote_ident(self.table_name+"_meta_idx"), dimensions=self.num_dimensions)
106112

107113
def _get_embedding_index_name(self):
108114
return self._quote_ident(self.table_name+"_embedding_idx")
@@ -113,10 +119,10 @@ def drop_embedding_index_query(self):
113119
def delete_all_query(self):
114120
return "TRUNCATE {table_name};".format(table_name=self._quote_ident(self.table_name))
115121

116-
def delete_by_ids_query(self, id: List[uuid.UUID]) -> Tuple[str, List]:
117-
query = "DELETE FROM {table_name} WHERE id = ANY($1::uuid[]);".format(
118-
table_name=self._quote_ident(self.table_name))
119-
return (query, [id])
122+
def delete_by_ids_query(self, ids: Union[List[uuid.UUID], List[str]]) -> Tuple[str, List]:
123+
query = "DELETE FROM {table_name} WHERE id = ANY($1::{id_type}[]);".format(
124+
table_name=self._quote_ident(self.table_name), id_type=self.id_type)
125+
return (query, [ids])
120126

121127
def delete_by_metadata_query(self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:
122128
params: List[Any] = []
@@ -220,7 +226,8 @@ def __init__(
220226
service_url: str,
221227
table_name: str,
222228
num_dimensions: int,
223-
distance_type: str = 'cosine') -> None:
229+
distance_type: str = 'cosine',
230+
id_type='UUID') -> None:
224231
"""
225232
Initializes a async client for storing vector data.
226233
@@ -230,7 +237,8 @@ def __init__(
230237
num_dimensions (int): The number of dimensions for the embedding vector.
231238
distance_type (str, optional): The distance type for indexing. Default is 'cosine' or '<=>'.
232239
"""
233-
self.builder = QueryBuilder(table_name, num_dimensions, distance_type)
240+
self.builder = QueryBuilder(
241+
table_name, num_dimensions, distance_type, id_type)
234242
self.service_url = service_url
235243
self.pool = None
236244

@@ -317,11 +325,11 @@ async def delete_all(self, drop_index=True):
317325
async with await self.connect() as pool:
318326
await pool.execute(query)
319327

320-
async def delete_by_ids(self, id: List[uuid.UUID]):
328+
async def delete_by_ids(self, ids: Union[List[uuid.UUID], List[str]]):
321329
"""
322330
Delete records by id.
323331
"""
324-
(query, params) = self.builder.delete_by_ids_query(id)
332+
(query, params) = self.builder.delete_by_ids_query(ids)
325333
async with await self.connect() as pool:
326334
return await pool.fetch(query, *params)
327335

@@ -417,8 +425,10 @@ def __init__(
417425
service_url: str,
418426
table_name: str,
419427
num_dimensions: int,
420-
distance_type: str = 'cosine') -> None:
421-
self.builder = QueryBuilder(table_name, num_dimensions, distance_type)
428+
distance_type: str = 'cosine',
429+
id_type='UUID') -> None:
430+
self.builder = QueryBuilder(
431+
table_name, num_dimensions, distance_type, id_type)
422432
self.service_url = service_url
423433
self.pool = None
424434
psycopg2.extras.register_uuid()
@@ -547,11 +557,11 @@ def delete_all(self, drop_index=True):
547557
with conn.cursor() as cur:
548558
cur.execute(query)
549559

550-
def delete_by_ids(self, id: List[uuid.UUID]):
560+
def delete_by_ids(self, ids: Union[List[uuid.UUID], List[str]]):
551561
"""
552562
Delete records by id.
553563
"""
554-
(query, params) = self.builder.delete_by_ids_query(id)
564+
(query, params) = self.builder.delete_by_ids_query(ids)
555565
query, params = self._translate_to_pyformat(query, params)
556566
with self.connect() as conn:
557567
with conn.cursor() as cur:

0 commit comments

Comments
 (0)