Skip to content

Commit 6053c16

Browse files
committed
make filtering safer
1 parent d839964 commit 6053c16

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

nbs/00_vector.ipynb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@
241241
" .format(index_name=self._get_embedding_index_name(), table_name=self._quote_ident(self.table_name), column_name=self._quote_ident(column_name), index_method=index_method, num_lists=num_lists)\n",
242242
"\n",
243243
" def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]]) -> Tuple[str, List]:\n",
244+
" if filter == None:\n",
245+
" return (\"TRUE\", params)\n",
246+
"\n",
244247
" if isinstance(filter, dict):\n",
245248
" where = \"metadata @> ${index}\".format(index=len(params)+1)\n",
246249
" json_object = json.dumps(filter)\n",
@@ -253,7 +256,7 @@
253256
" index=len(params) + 1)\n",
254257
" params = params + [any_params]\n",
255258
" else:\n",
256-
" where = \"TRUE\"\n",
259+
" raise ValueError(\"Unknown filter type: {filter_type}\".format(filter_type=type(filter)))\n",
257260
"\n",
258261
" return (where, params)\n",
259262
"\n",

timescale_vector/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ def create_ivfflat_index_query(self, num_records):
166166
.format(index_name=self._get_embedding_index_name(), table_name=self._quote_ident(self.table_name), column_name=self._quote_ident(column_name), index_method=index_method, num_lists=num_lists)
167167

168168
def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]]) -> Tuple[str, List]:
169+
if filter == None:
170+
return ("TRUE", params)
171+
169172
if isinstance(filter, dict):
170173
where = "metadata @> ${index}".format(index=len(params)+1)
171174
json_object = json.dumps(filter)
@@ -178,7 +181,7 @@ def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str
178181
index=len(params) + 1)
179182
params = params + [any_params]
180183
else:
181-
where = "TRUE"
184+
raise ValueError("Unknown filter type: {filter_type}".format(filter_type=type(filter)))
182185

183186
return (where, params)
184187

0 commit comments

Comments
 (0)