|
88 | 88 | "SEARCH_RESULT_DISTANCE_IDX = 4"
|
89 | 89 | ]
|
90 | 90 | },
|
| 91 | + { |
| 92 | + "cell_type": "code", |
| 93 | + "execution_count": null, |
| 94 | + "metadata": {}, |
| 95 | + "outputs": [], |
| 96 | + "source": [ |
| 97 | + "#| export\n", |
| 98 | + "\n", |
| 99 | + "class Predicates:\n", |
| 100 | + " logical_operators = {\n", |
| 101 | + " \"AND\": \"AND\",\n", |
| 102 | + " \"OR\": \"OR\",\n", |
| 103 | + " \"NOT\": \"NOT\",\n", |
| 104 | + " }\n", |
| 105 | + "\n", |
| 106 | + " operators_mapping = {\n", |
| 107 | + " \"=\": \"=\",\n", |
| 108 | + " \"==\": \"=\",\n", |
| 109 | + " \">=\": \">=\",\n", |
| 110 | + " \">\": \">\",\n", |
| 111 | + " \"<=\": \"<=\",\n", |
| 112 | + " \"<\": \"<\",\n", |
| 113 | + " \"!=\": \"<>\",\n", |
| 114 | + " }\n", |
| 115 | + "\n", |
| 116 | + " def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'):\n", |
| 117 | + " if operator not in self.logical_operators: \n", |
| 118 | + " raise ValueError(f\"invalid operator: {operator}\")\n", |
| 119 | + " self.operator = operator\n", |
| 120 | + " self.clauses = list(clauses)\n", |
| 121 | + "\n", |
| 122 | + " def add_clause(self, clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]):\n", |
| 123 | + " self.clauses.append(clause)\n", |
| 124 | + "\n", |
| 125 | + " def add_clauses(self, clauses_list: List[Union['Predicates', Tuple[str, str], Tuple[str, str, str]]]):\n", |
| 126 | + " self.clauses.extend(clauses_list)\n", |
| 127 | + " \n", |
| 128 | + " def __and__(self, other):\n", |
| 129 | + " new_predicates = Predicates(self, other, operator='AND')\n", |
| 130 | + " return new_predicates\n", |
| 131 | + "\n", |
| 132 | + " def __or__(self, other):\n", |
| 133 | + " new_predicates = Predicates(self, other, operator='OR')\n", |
| 134 | + " return new_predicates\n", |
| 135 | + "\n", |
| 136 | + " def __invert__(self):\n", |
| 137 | + " new_predicates = Predicates(self, operator='NOT')\n", |
| 138 | + " return new_predicates\n", |
| 139 | + "\n", |
| 140 | + " def __repr__(self):\n", |
| 141 | + " if self.operator:\n", |
| 142 | + " return f\"{self.operator}({', '.join(repr(clause) for clause in self.clauses)})\"\n", |
| 143 | + " else:\n", |
| 144 | + " return repr(self.clauses)\n", |
| 145 | + "\n", |
| 146 | + " def build_query(self, params: List) -> Tuple[str, List]:\n", |
| 147 | + " if not self.clauses:\n", |
| 148 | + " return \"\", []\n", |
| 149 | + "\n", |
| 150 | + " where_conditions = [] \n", |
| 151 | + "\n", |
| 152 | + " for clause in self.clauses:\n", |
| 153 | + " if isinstance(clause, Predicates):\n", |
| 154 | + " child_where_clause, params = clause.build_query(params)\n", |
| 155 | + " where_conditions.append(f\"({child_where_clause})\")\n", |
| 156 | + " elif isinstance(clause, tuple):\n", |
| 157 | + " if len(clause) == 2:\n", |
| 158 | + " field, value = clause\n", |
| 159 | + " operator = \"=\" # Default operator\n", |
| 160 | + " elif len(clause) == 3:\n", |
| 161 | + " field, operator, value = clause\n", |
| 162 | + " if operator not in self.operators_mapping:\n", |
| 163 | + " raise ValueError(f\"Invalid operator: {operator}\") \n", |
| 164 | + " operator = self.operators_mapping[operator]\n", |
| 165 | + " else:\n", |
| 166 | + " raise ValueError(\"Invalid clause format\")\n", |
| 167 | + " \n", |
| 168 | + " field_cast = ''\n", |
| 169 | + " if isinstance(value, int):\n", |
| 170 | + " field_cast = '::int'\n", |
| 171 | + " elif isinstance(value, float):\n", |
| 172 | + " field_cast = '::numeric' \n", |
| 173 | + "\n", |
| 174 | + " index = len(params)+1\n", |
| 175 | + " param_name = f\"${index}\"\n", |
| 176 | + " where_conditions.append(f\"(metadata->>'{field}'){field_cast} {operator} {param_name}\")\n", |
| 177 | + " params.append(value) \n", |
| 178 | + "\n", |
| 179 | + " if self.operator == 'NOT':\n", |
| 180 | + " or_clauses = (\" OR \").join(where_conditions)\n", |
| 181 | + " #use IS DISTINCT FROM to treat all-null clauses as False and pass the filter\n", |
| 182 | + " where_clause = f\"TRUE IS DISTINCT FROM ({or_clauses})\"\n", |
| 183 | + " else:\n", |
| 184 | + " where_clause = (\" \"+self.operator+\" \").join(where_conditions) \n", |
| 185 | + " return where_clause, params" |
| 186 | + ] |
| 187 | + }, |
91 | 188 | {
|
92 | 189 | "cell_type": "code",
|
93 | 190 | "execution_count": null,
|
|
260 | 357 | "\n",
|
261 | 358 | " return (where, params)\n",
|
262 | 359 | "\n",
|
263 |
| - " def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int = 10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
| 360 | + " def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int = 10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, predicates: Optional[Predicates] = None) -> Tuple[str, List]:\n", |
264 | 361 | " \"\"\"\n",
|
265 | 362 | " Generates a similarity query.\n",
|
266 | 363 | "\n",
|
|
283 | 380 | " distance = \"-1.0\"\n",
|
284 | 381 | " order_by_clause = \"\"\n",
|
285 | 382 | "\n",
|
286 |
| - " (where, params) = self._where_clause_for_filter(params, filter)\n", |
| 383 | + " where_clauses = []\n", |
| 384 | + " if filter is not None:\n", |
| 385 | + " (where_filter, params) = self._where_clause_for_filter(params, filter)\n", |
| 386 | + " where_clauses.append(where_filter)\n", |
| 387 | + "\n", |
| 388 | + " if predicates is not None:\n", |
| 389 | + " (where_predicates, params) = predicates.build_query(params)\n", |
| 390 | + " where_clauses.append(where_predicates)\n", |
| 391 | + " \n", |
| 392 | + " if len(where_clauses) > 0:\n", |
| 393 | + " where = \" AND \".join(where_clauses)\n", |
| 394 | + " else:\n", |
| 395 | + " where = \"TRUE\"\n", |
287 | 396 | "\n",
|
288 | 397 | " query = '''\n",
|
289 | 398 | " SELECT\n",
|
|
534 | 643 | " query_embedding: Optional[List[float]] = None,\n",
|
535 | 644 | " # The number of nearest neighbors to retrieve. Default is 10.\n",
|
536 | 645 | " limit: int = 10,\n",
|
537 |
| - " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None): # A filter for metadata. Default is None.\n", |
| 646 | + " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n", |
| 647 | + " predicates: Optional[Predicates] = None): # A filter for metadata. Default is None.\n", |
538 | 648 | " \"\"\"\n",
|
539 | 649 | " Retrieves similar records using a similarity query.\n",
|
540 | 650 | "\n",
|
541 | 651 | " Returns:\n",
|
542 | 652 | " List: List of similar records.\n",
|
543 | 653 | " \"\"\"\n",
|
544 | 654 | " (query, params) = self.builder.search_query(\n",
|
545 |
| - " query_embedding, limit, filter)\n", |
| 655 | + " query_embedding, limit, filter, predicates)\n", |
546 | 656 | " async with await self.connect() as pool:\n",
|
547 | 657 | " return await pool.fetch(query, *params)"
|
548 | 658 | ]
|
|
653 | 763 | "\n",
|
654 | 764 | "> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
|
655 | 765 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
|
656 |
| - "> ne)\n", |
| 766 | + "> ne, predicates:Optional[__main__.Predicates]=None)\n", |
657 | 767 | "\n",
|
658 | 768 | "Retrieves similar records using a similarity query.\n",
|
659 | 769 | "\n",
|
|
669 | 779 | "\n",
|
670 | 780 | "> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
|
671 | 781 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n",
|
672 |
| - "> ne)\n", |
| 782 | + "> ne, predicates:Optional[__main__.Predicates]=None)\n", |
673 | 783 | "\n",
|
674 | 784 | "Retrieves similar records using a similarity query.\n",
|
675 | 785 | "\n",
|
|
722 | 832 | "\n",
|
723 | 833 | "await vec.upsert([\n",
|
724 | 834 | " (uuid.uuid4(), '''{\"key\":\"val\"}''', \"the brown fox\", [1.0, 1.3]),\n",
|
725 |
| - " (uuid.uuid4(), '''{\"key\":\"val2\"}''', \"the brown fox\", [1.0, 1.4]),\n", |
| 835 | + " (uuid.uuid4(), '''{\"key\":\"val2\", \"key_10\": \"10\", \"key_11\": \"11.3\"}''', \"the brown fox\", [1.0, 1.4]),\n", |
726 | 836 | " (uuid.uuid4(), '''{\"key2\":\"val\"}''', \"the brown fox\", [1.0, 1.5]),\n",
|
727 | 837 | " (uuid.uuid4(), '''{\"key2\":\"val\"}''', \"the brown fox\", [1.0, 1.6]),\n",
|
728 | 838 | " (uuid.uuid4(), '''{\"key2\":\"val\"}''', \"the brown fox\", [1.0, 1.6]),\n",
|
|
769 | 879 | "\n",
|
770 | 880 | "assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)\n",
|
771 | 881 | "\n",
|
| 882 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\")))\n", |
| 883 | + "assert len(rec) == 1\n", |
| 884 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"==\", \"val2\")))\n", |
| 885 | + "assert len(rec) == 1\n", |
| 886 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)))\n", |
| 887 | + "assert len(rec) == 1\n", |
| 888 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 10)))\n", |
| 889 | + "assert len(rec) == 0\n", |
| 890 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10)))\n", |
| 891 | + "assert len(rec) == 1\n", |
| 892 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10.0)))\n", |
| 893 | + "assert len(rec) == 1\n", |
| 894 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<=\", 11.3)))\n", |
| 895 | + "assert len(rec) == 1\n", |
| 896 | + "rec = await vec.search(limit=4, predicates=Predicates((\"key_11\", \">=\", 11.29999)))\n", |
| 897 | + "assert len(rec) == 1\n", |
| 898 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<\", 11.299999)))\n", |
| 899 | + "assert len(rec) == 0\n", |
| 900 | + "\n", |
| 901 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(*[(\"key\", \"val2\"), (\"key_10\", \"<\", 100)]))\n", |
| 902 | + "assert len(rec) == 1\n", |
| 903 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\"), (\"key_10\", \"<\", 100), operator='AND'))\n", |
| 904 | + "assert len(rec) == 1\n", |
| 905 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\"), (\"key_2\", \"val_2\"), operator='OR'))\n", |
| 906 | + "assert len(rec) == 2\n", |
| 907 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) & (Predicates((\"key\", \"val2\")) | Predicates((\"key_2\", \"val_2\")))) \n", |
| 908 | + "assert len(rec) == 1\n", |
| 909 | + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) and (Predicates((\"key\", \"val2\")) or Predicates((\"key_2\", \"val_2\")))) \n", |
| 910 | + "assert len(rec) == 1\n", |
| 911 | + "rec = await vec.search(limit=4, predicates=~Predicates((\"key\", \"val2\"), (\"key_10\", \"<\", 100)))\n", |
| 912 | + "assert len(rec) == 4\n", |
| 913 | + "\n", |
| 914 | + "\n", |
772 | 915 | "try:\n",
|
773 | 916 | " # can't upsert using both keys and dictionaries\n",
|
774 | 917 | " await vec.upsert([\n",
|
|
1069 | 1212 | " with conn.cursor() as cur:\n",
|
1070 | 1213 | " cur.execute(query)\n",
|
1071 | 1214 | "\n",
|
1072 |
| - " def search(self, query_embedding: Optional[List[float]] = None, limit: int = 10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):\n", |
| 1215 | + " def search(self, \n", |
| 1216 | + " query_embedding: Optional[List[float]] = None, \n", |
| 1217 | + " limit: int = 10, \n", |
| 1218 | + " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n", |
| 1219 | + " predicates: Optional[Predicates] = None):\n", |
1073 | 1220 | " \"\"\"\n",
|
1074 | 1221 | " Retrieves similar records using a similarity query.\n",
|
1075 | 1222 | "\n",
|
|
1087 | 1234 | " query_embedding_np = None\n",
|
1088 | 1235 | "\n",
|
1089 | 1236 | " (query, params) = self.builder.search_query(\n",
|
1090 |
| - " query_embedding_np, limit, filter)\n", |
| 1237 | + " query_embedding_np, limit, filter, predicates)\n", |
1091 | 1238 | " query, params = self._translate_to_pyformat(query, params)\n",
|
1092 | 1239 | " with self.connect() as conn:\n",
|
1093 | 1240 | " with conn.cursor() as cur:\n",
|
|
1207 | 1354 | "\n",
|
1208 | 1355 | "> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
|
1209 | 1356 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
|
1210 |
| - "> e)\n", |
| 1357 | + "> e, predicates:Optional[__main__.Predicates]=None)\n", |
1211 | 1358 | "\n",
|
1212 | 1359 | "Retrieves similar records using a similarity query.\n",
|
1213 | 1360 | "\n",
|
|
1228 | 1375 | "\n",
|
1229 | 1376 | "> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n",
|
1230 | 1377 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n",
|
1231 |
| - "> e)\n", |
| 1378 | + "> e, predicates:Optional[__main__.Predicates]=None)\n", |
1232 | 1379 | "\n",
|
1233 | 1380 | "Retrieves similar records using a similarity query.\n",
|
1234 | 1381 | "\n",
|
|
1363 | 1510 | "assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)\n",
|
1364 | 1511 | "assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556\n",
|
1365 | 1512 | "\n",
|
| 1513 | + "rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\")))\n", |
| 1514 | + "assert len(rec) == 1\n", |
| 1515 | + "\n", |
1366 | 1516 | "rec = vec.search([1.0, 2.0], limit=4, filter=[\n",
|
1367 | 1517 | " {\"key_1\": \"val_1\"}, {\"key2\": \"val2\"}])\n",
|
1368 | 1518 | "len(rec) == 2\n",
|
|
0 commit comments