Skip to content

Commit 5a17cfe

Browse files
committed
Add metadata predicate functionality
1 parent 6053c16 commit 5a17cfe

File tree

3 files changed

+294
-21
lines changed

3 files changed

+294
-21
lines changed

nbs/00_vector.ipynb

Lines changed: 161 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,103 @@
8888
"SEARCH_RESULT_DISTANCE_IDX = 4"
8989
]
9090
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": null,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"#| export\n",
98+
"\n",
99+
class Predicates:
    """Composable metadata filter that renders to a parameterized SQL condition.

    A predicate holds a logical operator ('AND', 'OR' or 'NOT') and a list of
    clauses. Each clause is either another ``Predicates`` instance or a tuple:

    * ``(field, value)``            -- equality comparison
    * ``(field, operator, value)``  -- comparison using a key of
      ``operators_mapping``

    Instances can be combined with ``&``, ``|`` and ``~``. Note that the
    Python keywords ``and`` / ``or`` do NOT combine predicates; they simply
    return one of their operands.
    """

    # Logical connectives accepted by the constructor.
    logical_operators = {
        "AND": "AND",
        "OR": "OR",
        "NOT": "NOT",
    }

    # User-facing comparison operators mapped to their SQL spelling.
    operators_mapping = {
        "=": "=",
        "==": "=",
        ">=": ">=",
        ">": ">",
        "<=": "<=",
        "<": "<",
        "!=": "<>",
    }

    def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'):
        """Create a predicate from zero or more clauses.

        Args:
            *clauses: Nested predicates and/or comparison tuples.
            operator: Logical operator joining the clauses; one of the keys
                of ``logical_operators``.

        Raises:
            ValueError: If ``operator`` is not a known logical operator.
        """
        if operator not in self.logical_operators:
            raise ValueError(f"invalid operator: {operator}")
        self.operator = operator
        self.clauses = list(clauses)

    def add_clause(self, clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]):
        """Append a single clause to this predicate."""
        self.clauses.append(clause)

    def add_clauses(self, clauses_list: List[Union['Predicates', Tuple[str, str], Tuple[str, str, str]]]):
        """Append every clause in ``clauses_list`` to this predicate."""
        self.clauses.extend(clauses_list)

    def __and__(self, other):
        """Return a new predicate matching when both operands match (``a & b``)."""
        return Predicates(self, other, operator='AND')

    def __or__(self, other):
        """Return a new predicate matching when either operand matches (``a | b``)."""
        return Predicates(self, other, operator='OR')

    def __invert__(self):
        """Return a new predicate negating this one (``~a``)."""
        return Predicates(self, operator='NOT')

    def __repr__(self):
        # The constructor guarantees a non-empty operator, so there is a
        # single representation: OPERATOR(clause, clause, ...).
        return f"{self.operator}({', '.join(repr(clause) for clause in self.clauses)})"

    def _comparison_sql(self, clause: tuple, params: List) -> str:
        """Render one comparison tuple and append its value to ``params``.

        Raises:
            ValueError: If the tuple does not have 2 or 3 items, or names an
                operator missing from ``operators_mapping``.
        """
        if len(clause) == 2:
            field, value = clause
            operator = "="  # Default operator
        elif len(clause) == 3:
            field, operator, value = clause
            if operator not in self.operators_mapping:
                raise ValueError(f"Invalid operator: {operator}")
            operator = self.operators_mapping[operator]
        else:
            raise ValueError("Invalid clause format")

        # ``->>`` yields text, so cast the field to compare against typed
        # values. bool must be tested first because bool is a subclass of int.
        field_cast = ''
        if isinstance(value, bool):
            field_cast = '::boolean'
        elif isinstance(value, int):
            field_cast = '::int'
        elif isinstance(value, float):
            field_cast = '::numeric'

        # The field name is embedded in a SQL string literal; double any
        # single quote so it cannot terminate the literal.
        safe_field = str(field).replace("'", "''")

        params.append(value)
        param_name = f"${len(params)}"
        return f"(metadata->>'{safe_field}'){field_cast} {operator} {param_name}"

    def build_query(self, params: List) -> Tuple[str, List]:
        """Render this predicate as a SQL condition with ``$n`` placeholders.

        Args:
            params: Parameters already bound by the enclosing query. New
                values are appended to this list in place; placeholder
                numbers continue from its current length.

        Returns:
            ``(where_clause, params)``. ``where_clause`` is the empty string
            when there is nothing to filter on; ``params`` is always the
            (possibly extended) list that was passed in.

        Raises:
            ValueError: If a clause is not a ``Predicates`` or a valid tuple.
        """
        if not self.clauses:
            # Keep the caller's parameters intact; there is nothing to add.
            return "", params

        where_conditions = []
        for clause in self.clauses:
            if isinstance(clause, Predicates):
                child_where_clause, params = clause.build_query(params)
                # An empty child contributes no condition; emitting "()"
                # would be invalid SQL.
                if child_where_clause:
                    where_conditions.append(f"({child_where_clause})")
            elif isinstance(clause, tuple):
                where_conditions.append(self._comparison_sql(clause, params))
            else:
                # Silently dropping a clause would widen the filter.
                raise ValueError("Invalid clause format")

        if not where_conditions:
            return "", params

        if self.operator == 'NOT':
            or_clauses = " OR ".join(where_conditions)
            # use IS DISTINCT FROM to treat all-null clauses as False and pass the filter
            where_clause = f"TRUE IS DISTINCT FROM ({or_clauses})"
        else:
            where_clause = f" {self.operator} ".join(where_conditions)
        return where_clause, params
186+
]
187+
},
91188
{
92189
"cell_type": "code",
93190
"execution_count": null,
@@ -260,7 +357,7 @@
260357
"\n",
261358
" return (where, params)\n",
262359
"\n",
263-
" def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int = 10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n",
360+
" def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int = 10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, predicates: Optional[Predicates] = None) -> Tuple[str, List]:\n",
264361
" \"\"\"\n",
265362
" Generates a similarity query.\n",
266363
"\n",
@@ -283,7 +380,19 @@
283380
" distance = \"-1.0\"\n",
284381
" order_by_clause = \"\"\n",
285382
"\n",
286-
" (where, params) = self._where_clause_for_filter(params, filter)\n",
383+
" where_clauses = []\n",
384+
" if filter is not None:\n",
385+
" (where_filter, params) = self._where_clause_for_filter(params, filter)\n",
386+
" where_clauses.append(where_filter)\n",
387+
"\n",
388+
" if predicates is not None:\n",
389+
" (where_predicates, params) = predicates.build_query(params)\n",
390+
" where_clauses.append(where_predicates)\n",
391+
" \n",
392+
" if len(where_clauses) > 0:\n",
393+
" where = \" AND \".join(where_clauses)\n",
394+
" else:\n",
395+
" where = \"TRUE\"\n",
287396
"\n",
288397
" query = '''\n",
289398
" SELECT\n",
@@ -534,15 +643,16 @@
534643
" query_embedding: Optional[List[float]] = None,\n",
535644
" # The number of nearest neighbors to retrieve. Default is 10.\n",
536645
" limit: int = 10,\n",
537-
" filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None): # A filter for metadata. Default is None.\n",
646+
" filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n",
647+
"                   predicates: Optional[Predicates] = None): # Metadata predicates, ANDed with `filter` when both are given. Default is None.\n",
538648
" \"\"\"\n",
539649
" Retrieves similar records using a similarity query.\n",
540650
"\n",
541651
" Returns:\n",
542652
" List: List of similar records.\n",
543653
" \"\"\"\n",
544654
" (query, params) = self.builder.search_query(\n",
545-
" query_embedding, limit, filter)\n",
655+
" query_embedding, limit, filter, predicates)\n",
546656
" async with await self.connect() as pool:\n",
547657
" return await pool.fetch(query, *params)"
548658
]
@@ -653,7 +763,7 @@
653763
"\n",
654764
"> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
655765
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
656-
"> ne)\n",
766+
"> ne, predicates:Optional[__main__.Predicates]=None)\n",
657767
"\n",
658768
"Retrieves similar records using a similarity query.\n",
659769
"\n",
@@ -669,7 +779,7 @@
669779
"\n",
670780
"> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
671781
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
672-
"> ne)\n",
782+
"> ne, predicates:Optional[__main__.Predicates]=None)\n",
673783
"\n",
674784
"Retrieves similar records using a similarity query.\n",
675785
"\n",
@@ -722,7 +832,7 @@
722832
"\n",
723833
"await vec.upsert([\n",
724834
" (uuid.uuid4(), '''{\"key\":\"val\"}''', \"the brown fox\", [1.0, 1.3]),\n",
725-
" (uuid.uuid4(), '''{\"key\":\"val2\"}''', \"the brown fox\", [1.0, 1.4]),\n",
835+
" (uuid.uuid4(), '''{\"key\":\"val2\", \"key_10\": \"10\", \"key_11\": \"11.3\"}''', \"the brown fox\", [1.0, 1.4]),\n",
726836
" (uuid.uuid4(), '''{\"key2\":\"val\"}''', \"the brown fox\", [1.0, 1.5]),\n",
727837
" (uuid.uuid4(), '''{\"key2\":\"val\"}''', \"the brown fox\", [1.0, 1.6]),\n",
728838
" (uuid.uuid4(), '''{\"key2\":\"val\"}''', \"the brown fox\", [1.0, 1.6]),\n",
@@ -769,6 +879,39 @@
769879
"\n",
770880
"assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)\n",
771881
"\n",
882+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\")))\n",
883+
"assert len(rec) == 1\n",
884+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"==\", \"val2\")))\n",
885+
"assert len(rec) == 1\n",
886+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)))\n",
887+
"assert len(rec) == 1\n",
888+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 10)))\n",
889+
"assert len(rec) == 0\n",
890+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10)))\n",
891+
"assert len(rec) == 1\n",
892+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10.0)))\n",
893+
"assert len(rec) == 1\n",
894+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<=\", 11.3)))\n",
895+
"assert len(rec) == 1\n",
896+
"rec = await vec.search(limit=4, predicates=Predicates((\"key_11\", \">=\", 11.29999)))\n",
897+
"assert len(rec) == 1\n",
898+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<\", 11.299999)))\n",
899+
"assert len(rec) == 0\n",
900+
"\n",
901+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(*[(\"key\", \"val2\"), (\"key_10\", \"<\", 100)]))\n",
902+
"assert len(rec) == 1\n",
903+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\"), (\"key_10\", \"<\", 100), operator='AND'))\n",
904+
"assert len(rec) == 1\n",
905+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\"), (\"key_2\", \"val_2\"), operator='OR'))\n",
906+
"assert len(rec) == 2\n",
907+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) & (Predicates((\"key\", \"val2\")) | Predicates((\"key_2\", \"val_2\")))) \n",
908+
"assert len(rec) == 1\n",
909+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) and (Predicates((\"key\", \"val2\")) or Predicates((\"key_2\", \"val_2\")))) \n",
910+
"assert len(rec) == 1\n",
911+
"rec = await vec.search(limit=4, predicates=~Predicates((\"key\", \"val2\"), (\"key_10\", \"<\", 100)))\n",
912+
"assert len(rec) == 4\n",
913+
"\n",
914+
"\n",
772915
"try:\n",
773916
" # can't upsert using both keys and dictionaries\n",
774917
" await vec.upsert([\n",
@@ -1069,7 +1212,11 @@
10691212
" with conn.cursor() as cur:\n",
10701213
" cur.execute(query)\n",
10711214
"\n",
1072-
" def search(self, query_embedding: Optional[List[float]] = None, limit: int = 10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):\n",
1215+
" def search(self, \n",
1216+
" query_embedding: Optional[List[float]] = None, \n",
1217+
" limit: int = 10, \n",
1218+
" filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n",
1219+
" predicates: Optional[Predicates] = None):\n",
10731220
" \"\"\"\n",
10741221
" Retrieves similar records using a similarity query.\n",
10751222
"\n",
@@ -1087,7 +1234,7 @@
10871234
" query_embedding_np = None\n",
10881235
"\n",
10891236
" (query, params) = self.builder.search_query(\n",
1090-
" query_embedding_np, limit, filter)\n",
1237+
" query_embedding_np, limit, filter, predicates)\n",
10911238
" query, params = self._translate_to_pyformat(query, params)\n",
10921239
" with self.connect() as conn:\n",
10931240
" with conn.cursor() as cur:\n",
@@ -1207,7 +1354,7 @@
12071354
"\n",
12081355
"> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
12091356
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
1210-
"> e)\n",
1357+
"> e, predicates:Optional[__main__.Predicates]=None)\n",
12111358
"\n",
12121359
"Retrieves similar records using a similarity query.\n",
12131360
"\n",
@@ -1228,7 +1375,7 @@
12281375
"\n",
12291376
"> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
12301377
"> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
1231-
"> e)\n",
1378+
"> e, predicates:Optional[__main__.Predicates]=None)\n",
12321379
"\n",
12331380
"Retrieves similar records using a similarity query.\n",
12341381
"\n",
@@ -1363,6 +1510,9 @@
13631510
"assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)\n",
13641511
"assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556\n",
13651512
"\n",
1513+
"rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\")))\n",
1514+
"assert len(rec) == 1\n",
1515+
"\n",
13661516
"rec = vec.search([1.0, 2.0], limit=4, filter=[\n",
13671517
" {\"key_1\": \"val_1\"}, {\"key2\": \"val2\"}])\n",
13681518
"len(rec) == 2\n",

timescale_vector/_modidx.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@
3333
'timescale_vector.client.Async.table_is_empty': ( 'vector.html#async.table_is_empty',
3434
'timescale_vector/client.py'),
3535
'timescale_vector.client.Async.upsert': ('vector.html#async.upsert', 'timescale_vector/client.py'),
36+
'timescale_vector.client.Predicates': ('vector.html#predicates', 'timescale_vector/client.py'),
37+
'timescale_vector.client.Predicates.__and__': ( 'vector.html#predicates.__and__',
38+
'timescale_vector/client.py'),
39+
'timescale_vector.client.Predicates.__init__': ( 'vector.html#predicates.__init__',
40+
'timescale_vector/client.py'),
41+
'timescale_vector.client.Predicates.__invert__': ( 'vector.html#predicates.__invert__',
42+
'timescale_vector/client.py'),
43+
'timescale_vector.client.Predicates.__or__': ( 'vector.html#predicates.__or__',
44+
'timescale_vector/client.py'),
45+
'timescale_vector.client.Predicates.__repr__': ( 'vector.html#predicates.__repr__',
46+
'timescale_vector/client.py'),
47+
'timescale_vector.client.Predicates.add_clause': ( 'vector.html#predicates.add_clause',
48+
'timescale_vector/client.py'),
49+
'timescale_vector.client.Predicates.add_clauses': ( 'vector.html#predicates.add_clauses',
50+
'timescale_vector/client.py'),
51+
'timescale_vector.client.Predicates.build_query': ( 'vector.html#predicates.build_query',
52+
'timescale_vector/client.py'),
3653
'timescale_vector.client.QueryBuilder': ('vector.html#querybuilder', 'timescale_vector/client.py'),
3754
'timescale_vector.client.QueryBuilder.__init__': ( 'vector.html#querybuilder.__init__',
3855
'timescale_vector/client.py'),

0 commit comments

Comments
 (0)