|
68 | 68 | "import asyncpg\n",
|
69 | 69 | "import uuid\n",
|
70 | 70 | "from pgvector.asyncpg import register_vector\n",
|
71 |
| - "from typing import (List, Optional, Union, Dict, Tuple)\n", |
72 |
| - "import json " |
| 71 | + "from typing import (List, Optional, Union, Dict, Tuple, Any)\n", |
| 72 | + "import json\n", |
| 73 | + "import numpy as np " |
73 | 74 | ]
|
74 | 75 | },
|
75 | 76 | {
|
|
192 | 193 | " return (query, [id])\n",
|
193 | 194 | "\n",
|
194 | 195 | " def delete_by_metadata_query (self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:\n",
|
195 |
| - " params = []\n", |
| 196 | + " params: List[Any] = []\n", |
196 | 197 | " (where, params) = self._where_clause_for_filter(params, filter)\n",
|
197 | 198 | " query = \"DELETE FROM {table_name} WHERE {where};\".format(table_name=self._quote_ident(self.table_name), where=where)\n",
|
198 | 199 | " return (query, params) \n",
|
|
247 | 248 | "\n",
|
248 | 249 | " return (where, params) \n",
|
249 | 250 | "\n",
|
250 |
| - " def search_query(self, query_embedding: Optional[List[float]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
| 251 | + " def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
251 | 252 | " \"\"\"\n",
|
252 | 253 | " Generates a similarity query.\n",
|
253 | 254 | "\n",
|
|
259 | 260 | " Returns:\n",
|
260 | 261 | " Tuple[str, List]: A tuple containing the query and parameters.\n",
|
261 | 262 | " \"\"\"\n",
|
262 |
| - " params = []\n", |
| 263 | + " params: List[Any] = []\n", |
263 | 264 | " if query_embedding is not None:\n",
|
264 | 265 | " distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n",
|
265 | 266 | " params = params + [query_embedding]\n",
|
|
816 | 817 | "source": [
|
817 | 818 | "#| export\n",
|
818 | 819 | "class Sync:\n",
|
819 |
| - " translated_queries = {}\n", |
| 820 | + " translated_queries: Dict[str, str] = {}\n", |
820 | 821 | " \n",
|
821 | 822 | " def __init__(\n",
|
822 | 823 | " self,\n",
|
|
1033 | 1034 | " List: List of similar records.\n",
|
1034 | 1035 | " \"\"\"\n",
|
1035 | 1036 | " if query_embedding is not None:\n",
|
1036 |
| - " query_embedding = np.array(query_embedding)\n", |
| 1037 | + " query_embedding_np = np.array(query_embedding)\n", |
| 1038 | + " else:\n", |
| 1039 | + " query_embedding_np = None \n", |
1037 | 1040 | " \n",
|
1038 |
| - " (query, params) = self.builder.search_query(query_embedding, limit, filter)\n", |
| 1041 | + " (query, params) = self.builder.search_query(query_embedding_np, limit, filter)\n", |
1039 | 1042 | " query, params = self._translate_to_pyformat(query, params)\n",
|
1040 | 1043 | " with self.connect() as conn:\n",
|
1041 | 1044 | " with conn.cursor() as cur:\n",
|
|
0 commit comments