Skip to content

Commit 443961b

Browse files
committed
Improve Predicate creation interface
Previously Predicates had to be created from tuples: Predicates(("key", "==", "val)) But, as you can see above the single-predicate case is weird because of the extra paranthesis. This change allows for: Predicates("key", "==", "val) That is, for the single-predicate case it accepts 3 arguments instead of a tuple.
1 parent 3d12b32 commit 443961b

File tree

2 files changed

+66
-52
lines changed

2 files changed

+66
-52
lines changed

nbs/00_vector.ipynb

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": null,
15+
"execution_count": 97,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": null,
24+
"execution_count": 98,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -31,7 +31,7 @@
3131
},
3232
{
3333
"cell_type": "code",
34-
"execution_count": null,
34+
"execution_count": 99,
3535
"metadata": {},
3636
"outputs": [],
3737
"source": [
@@ -42,7 +42,7 @@
4242
},
4343
{
4444
"cell_type": "code",
45-
"execution_count": null,
45+
"execution_count": 100,
4646
"metadata": {},
4747
"outputs": [],
4848
"source": [
@@ -52,7 +52,7 @@
5252
},
5353
{
5454
"cell_type": "code",
55-
"execution_count": null,
55+
"execution_count": 101,
5656
"metadata": {},
5757
"outputs": [],
5858
"source": [
@@ -73,7 +73,7 @@
7373
},
7474
{
7575
"cell_type": "code",
76-
"execution_count": null,
76+
"execution_count": 102,
7777
"metadata": {},
7878
"outputs": [],
7979
"source": [
@@ -146,7 +146,7 @@
146146
},
147147
{
148148
"cell_type": "code",
149-
"execution_count": null,
149+
"execution_count": 103,
150150
"metadata": {},
151151
"outputs": [],
152152
"source": [
@@ -275,7 +275,7 @@
275275
},
276276
{
277277
"cell_type": "code",
278-
"execution_count": null,
278+
"execution_count": 104,
279279
"metadata": {},
280280
"outputs": [],
281281
"source": [
@@ -290,7 +290,7 @@
290290
},
291291
{
292292
"cell_type": "code",
293-
"execution_count": null,
293+
"execution_count": 105,
294294
"metadata": {},
295295
"outputs": [],
296296
"source": [
@@ -388,7 +388,7 @@
388388
},
389389
{
390390
"cell_type": "code",
391-
"execution_count": null,
391+
"execution_count": 106,
392392
"metadata": {},
393393
"outputs": [],
394394
"source": [
@@ -411,7 +411,9 @@
411411
" \"!=\": \"<>\",\n",
412412
" }\n",
413413
"\n",
414-
" def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'):\n",
414+
" PredicateValue = Union[str, int, float]\n",
415+
"\n",
416+
" def __init__(self, *clauses: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue], operator: str = 'AND'):\n",
415417
" \"\"\"\n",
416418
" Predicates class defines predicates on the object metadata. Predicates can be combined using logical operators (&, |, and ~).\n",
417419
"\n",
@@ -425,9 +427,14 @@
425427
" if operator not in self.logical_operators: \n",
426428
" raise ValueError(f\"invalid operator: {operator}\")\n",
427429
" self.operator = operator\n",
428-
" self.clauses = list(clauses)\n",
430+
" if isinstance(clauses[0], str):\n",
431+
" if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):\n",
432+
" raise ValueError(f\"Invalid clause format: {clauses}\")\n",
433+
" self.clauses = [(clauses[0], clauses[1], clauses[2])]\n",
434+
" else:\n",
435+
" self.clauses = list(clauses)\n",
429436
"\n",
430-
" def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]):\n",
437+
" def add_clause(self, *clause: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue]):\n",
431438
" \"\"\"\n",
432439
" Add a clause to the predicates object.\n",
433440
"\n",
@@ -436,7 +443,12 @@
436443
" clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str]\n",
437444
" Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value).\n",
438445
" \"\"\"\n",
439-
" self.clauses.extend(list(clause))\n",
446+
" if isinstance(clause[0], str):\n",
447+
" if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):\n",
448+
" raise ValueError(f\"Invalid clause format: {clause}\")\n",
449+
" self.clauses.append((clause[0], clause[1], clause[2]))\n",
450+
" else:\n",
451+
" self.clauses.extend(list(clause))\n",
440452
" \n",
441453
" def __and__(self, other):\n",
442454
" new_predicates = Predicates(self, other, operator='AND')\n",
@@ -522,7 +534,7 @@
522534
},
523535
{
524536
"cell_type": "code",
525-
"execution_count": null,
537+
"execution_count": 107,
526538
"metadata": {},
527539
"outputs": [],
528540
"source": [
@@ -836,7 +848,7 @@
836848
},
837849
{
838850
"cell_type": "code",
839-
"execution_count": null,
851+
"execution_count": 108,
840852
"metadata": {},
841853
"outputs": [
842854
{
@@ -864,7 +876,7 @@
864876
"Generates a query to create the tables, indexes, and extensions needed to store the vector data."
865877
]
866878
},
867-
"execution_count": null,
879+
"execution_count": 108,
868880
"metadata": {},
869881
"output_type": "execute_result"
870882
}
@@ -883,7 +895,7 @@
883895
},
884896
{
885897
"cell_type": "code",
886-
"execution_count": null,
898+
"execution_count": 109,
887899
"metadata": {},
888900
"outputs": [],
889901
"source": [
@@ -1143,7 +1155,7 @@
11431155
},
11441156
{
11451157
"cell_type": "code",
1146-
"execution_count": null,
1158+
"execution_count": 110,
11471159
"metadata": {},
11481160
"outputs": [
11491161
{
@@ -1171,7 +1183,7 @@
11711183
"Creates necessary tables."
11721184
]
11731185
},
1174-
"execution_count": null,
1186+
"execution_count": 110,
11751187
"metadata": {},
11761188
"output_type": "execute_result"
11771189
}
@@ -1182,7 +1194,7 @@
11821194
},
11831195
{
11841196
"cell_type": "code",
1185-
"execution_count": null,
1197+
"execution_count": 111,
11861198
"metadata": {},
11871199
"outputs": [
11881200
{
@@ -1210,7 +1222,7 @@
12101222
"Creates necessary tables."
12111223
]
12121224
},
1213-
"execution_count": null,
1225+
"execution_count": 111,
12141226
"metadata": {},
12151227
"output_type": "execute_result"
12161228
}
@@ -1221,21 +1233,9 @@
12211233
},
12221234
{
12231235
"cell_type": "code",
1224-
"execution_count": null,
1236+
"execution_count": 112,
12251237
"metadata": {},
12261238
"outputs": [
1227-
{
1228-
"name": "stderr",
1229-
"output_type": "stream",
1230-
"text": [
1231-
"/Users/cevian/.pyenv/versions/3.11.4/envs/nbdev_env/lib/python3.11/site-packages/fastcore/docscrape.py:225: UserWarning: potentially wrong underline length... \n",
1232-
"Returns \n",
1233-
"-------- in \n",
1234-
"Retrieves similar records using a similarity query.\n",
1235-
"...\n",
1236-
" else: warn(msg)\n"
1237-
]
1238-
},
12391239
{
12401240
"data": {
12411241
"text/markdown": [
@@ -1285,7 +1285,7 @@
12851285
"| **Returns** | **List: List of similar records.** | | |"
12861286
]
12871287
},
1288-
"execution_count": null,
1288+
"execution_count": 112,
12891289
"metadata": {},
12901290
"output_type": "execute_result"
12911291
}
@@ -1296,7 +1296,7 @@
12961296
},
12971297
{
12981298
"cell_type": "code",
1299-
"execution_count": null,
1299+
"execution_count": 117,
13001300
"metadata": {},
13011301
"outputs": [],
13021302
"source": [
@@ -1317,7 +1317,7 @@
13171317
},
13181318
{
13191319
"cell_type": "code",
1320-
"execution_count": null,
1320+
"execution_count": 118,
13211321
"metadata": {},
13221322
"outputs": [],
13231323
"source": [
@@ -1393,19 +1393,21 @@
13931393
"assert len(rec) == 1\n",
13941394
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"==\", \"val2\")))\n",
13951395
"assert len(rec) == 1\n",
1396-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)))\n",
1396+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key\", \"==\", \"val2\"))\n",
1397+
"assert len(rec) == 1\n",
1398+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100))\n",
13971399
"assert len(rec) == 1\n",
1398-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 10)))\n",
1400+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 10))\n",
13991401
"assert len(rec) == 0\n",
1400-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10)))\n",
1402+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<=\", 10))\n",
14011403
"assert len(rec) == 1\n",
1402-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10.0)))\n",
1404+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<=\", 10.0))\n",
14031405
"assert len(rec) == 1\n",
1404-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<=\", 11.3)))\n",
1406+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_11\", \"<=\", 11.3))\n",
14051407
"assert len(rec) == 1\n",
1406-
"rec = await vec.search(limit=4, predicates=Predicates((\"key_11\", \">=\", 11.29999)))\n",
1408+
"rec = await vec.search(limit=4, predicates=Predicates(\"key_11\", \">=\", 11.29999))\n",
14071409
"assert len(rec) == 1\n",
1408-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<\", 11.299999)))\n",
1410+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_11\", \"<\", 11.299999))\n",
14091411
"assert len(rec) == 0\n",
14101412
"\n",
14111413
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(*[(\"key\", \"val2\"), (\"key_10\", \"<\", 100)]))\n",
@@ -1414,9 +1416,9 @@
14141416
"assert len(rec) == 1\n",
14151417
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\"), (\"key_2\", \"val_2\"), operator='OR'))\n",
14161418
"assert len(rec) == 2\n",
1417-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) & (Predicates((\"key\", \"val2\")) | Predicates((\"key_2\", \"val_2\")))) \n",
1419+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100) & (Predicates(\"key\",\"==\", \"val2\",) | Predicates(\"key_2\", \"==\", \"val_2\"))) \n",
14181420
"assert len(rec) == 1\n",
1419-
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) and (Predicates((\"key\", \"val2\")) or Predicates((\"key_2\", \"val_2\")))) \n",
1421+
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100) and (Predicates(\"key\",\"==\", \"val2\") or Predicates(\"key_2\",\"==\", \"val_2\"))) \n",
14201422
"assert len(rec) == 1\n",
14211423
"rec = await vec.search(limit=4, predicates=~Predicates((\"key\", \"val2\"), (\"key_10\", \"<\", 100)))\n",
14221424
"assert len(rec) == 4\n",
@@ -2193,7 +2195,7 @@
21932195
"assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556\n",
21942196
"assert rec[0][\"distance\"] == 0.0009438353921149556\n",
21952197
"\n",
2196-
"rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\")))\n",
2198+
"rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key\",\"==\", \"val2\"))\n",
21972199
"assert len(rec) == 1\n",
21982200
"\n",
21992201
"rec = vec.search([1.0, 2.0], limit=4, filter=[\n",

timescale_vector/client.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,9 @@ class Predicates:
307307
"!=": "<>",
308308
}
309309

310-
def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'):
310+
PredicateValue = Union[str, int, float]
311+
312+
def __init__(self, *clauses: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue], operator: str = 'AND'):
311313
"""
312314
Predicates class defines predicates on the object metadata. Predicates can be combined using logical operators (&, |, and ~).
313315
@@ -321,9 +323,14 @@ def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str
321323
if operator not in self.logical_operators:
322324
raise ValueError(f"invalid operator: {operator}")
323325
self.operator = operator
324-
self.clauses = list(clauses)
326+
if isinstance(clauses[0], str):
327+
if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):
328+
raise ValueError("Invalid clause format: {clauses}")
329+
self.clauses = [(clauses[0], clauses[1], clauses[2])]
330+
else:
331+
self.clauses = list(clauses)
325332

326-
def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]):
333+
def add_clause(self, *clause: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue]):
327334
"""
328335
Add a clause to the predicates object.
329336
@@ -332,7 +339,12 @@ def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, st
332339
clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str]
333340
Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value).
334341
"""
335-
self.clauses.extend(list(clause))
342+
if isinstance(clause[0], str):
343+
if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):
344+
raise ValueError("Invalid clause format: {clauses}")
345+
self.clauses.append((clause[0], clause[1], clause[2]))
346+
else:
347+
self.clauses.extend(list(clause))
336348

337349
def __and__(self, other):
338350
new_predicates = Predicates(self, other, operator='AND')

0 commit comments

Comments
 (0)