Skip to content

Commit 06c0bae

Browse files
committed
Typing fixes
1 parent 645bab4 commit 06c0bae

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

nbs/00_vector.ipynb

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@
6868
"import asyncpg\n",
6969
"import uuid\n",
7070
"from pgvector.asyncpg import register_vector\n",
71-
"from typing import (List, Optional, Union, Dict, Tuple)\n",
72-
"import json "
71+
"from typing import (List, Optional, Union, Dict, Tuple, Any)\n",
72+
"import json\n",
73+
"import numpy as np "
7374
]
7475
},
7576
{
@@ -192,7 +193,7 @@
192193
" return (query, [id])\n",
193194
"\n",
194195
" def delete_by_metadata_query (self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:\n",
195-
" params = []\n",
196+
" params: List[Any] = []\n",
196197
" (where, params) = self._where_clause_for_filter(params, filter)\n",
197198
" query = \"DELETE FROM {table_name} WHERE {where};\".format(table_name=self._quote_ident(self.table_name), where=where)\n",
198199
" return (query, params) \n",
@@ -247,7 +248,7 @@
247248
"\n",
248249
" return (where, params) \n",
249250
"\n",
250-
" def search_query(self, query_embedding: Optional[List[float]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n",
251+
" def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n",
251252
" \"\"\"\n",
252253
" Generates a similarity query.\n",
253254
"\n",
@@ -259,7 +260,7 @@
259260
" Returns:\n",
260261
" Tuple[str, List]: A tuple containing the query and parameters.\n",
261262
" \"\"\"\n",
262-
" params = []\n",
263+
" params: List[Any] = []\n",
263264
" if query_embedding is not None:\n",
264265
" distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n",
265266
" params = params + [query_embedding]\n",
@@ -816,7 +817,7 @@
816817
"source": [
817818
"#| export\n",
818819
"class Sync:\n",
819-
" translated_queries = {}\n",
820+
" translated_queries: Dict[str, str] = {}\n",
820821
" \n",
821822
" def __init__(\n",
822823
" self,\n",
@@ -1033,9 +1034,11 @@
10331034
" List: List of similar records.\n",
10341035
" \"\"\"\n",
10351036
" if query_embedding is not None:\n",
1036-
" query_embedding = np.array(query_embedding)\n",
1037+
" query_embedding_np = np.array(query_embedding)\n",
1038+
" else:\n",
1039+
" query_embedding_np = None \n",
10371040
" \n",
1038-
" (query, params) = self.builder.search_query(query_embedding, limit, filter)\n",
1041+
" (query, params) = self.builder.search_query(query_embedding_np, limit, filter)\n",
10391042
" query, params = self._translate_to_pyformat(query, params)\n",
10401043
" with self.connect() as conn:\n",
10411044
" with conn.cursor() as cur:\n",

timescale_vector/client.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
import asyncpg
99
import uuid
1010
from pgvector.asyncpg import register_vector
11-
from typing import (List, Optional, Union, Dict, Tuple)
12-
import json
11+
from typing import (List, Optional, Union, Dict, Tuple, Any)
12+
import json
13+
import numpy as np
1314

1415
# %% ../nbs/00_vector.ipynb 7
1516
SEARCH_RESULT_ID_IDX = 0
@@ -117,7 +118,7 @@ def delete_by_ids_query(self, id: List[uuid.UUID]) -> Tuple[str, List]:
117118
return (query, [id])
118119

119120
def delete_by_metadata_query (self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:
120-
params = []
121+
params: List[Any] = []
121122
(where, params) = self._where_clause_for_filter(params, filter)
122123
query = "DELETE FROM {table_name} WHERE {where};".format(table_name=self._quote_ident(self.table_name), where=where)
123124
return (query, params)
@@ -172,7 +173,7 @@ def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str
172173

173174
return (where, params)
174175

175-
def search_query(self, query_embedding: Optional[List[float]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:
176+
def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:
176177
"""
177178
Generates a similarity query.
178179
@@ -184,7 +185,7 @@ def search_query(self, query_embedding: Optional[List[float]], limit: int=10, fi
184185
Returns:
185186
Tuple[str, List]: A tuple containing the query and parameters.
186187
"""
187-
params = []
188+
params: List[Any] = []
188189
if query_embedding is not None:
189190
distance = "embedding {op} ${index}".format(op=self.distance_type, index=len(params)+1)
190191
params = params + [query_embedding]
@@ -397,7 +398,7 @@ async def search(self,
397398

398399
# %% ../nbs/00_vector.ipynb 20
399400
class Sync:
400-
translated_queries = {}
401+
translated_queries: Dict[str, str] = {}
401402

402403
def __init__(
403404
self,
@@ -614,9 +615,11 @@ def search(self, query_embedding: Optional[List[float]]=None, limit: int=10, fil
614615
List: List of similar records.
615616
"""
616617
if query_embedding is not None:
617-
query_embedding = np.array(query_embedding)
618+
query_embedding_np = np.array(query_embedding)
619+
else:
620+
query_embedding_np = None
618621

619-
(query, params) = self.builder.search_query(query_embedding, limit, filter)
622+
(query, params) = self.builder.search_query(query_embedding_np, limit, filter)
620623
query, params = self._translate_to_pyformat(query, params)
621624
with self.connect() as conn:
622625
with conn.cursor() as cur:

0 commit comments

Comments
 (0)