|
48 | 48 | OPENAI_MODEL, |
49 | 49 | SECRET_NAME, |
50 | 50 | SECRET_NAME_ELASTICSEARCH, |
| 51 | + ZENML_CHATBOT_MODEL, |
51 | 52 | ) |
52 | 53 | from pgvector.psycopg2 import register_vector |
53 | 54 | from psycopg2.extensions import connection |
@@ -263,25 +264,59 @@ def get_db_conn() -> connection: |
263 | 264 | return psycopg2.connect(**CONNECTION_DETAILS) |
264 | 265 |
|
265 | 266 |
|
def get_topn_similar_docs_pgvector(
    query_embedding: List[float],
    conn: "psycopg2.extensions.connection",
    n: int = 5,
    include_metadata: bool = False,
    only_urls: bool = False,
) -> List[Tuple]:
    """Fetches the top n most similar documents to the given query embedding from the PostgreSQL database.

    Similarity is computed with pgvector's cosine-distance operator (`<=>`)
    over the `embeddings` table.

    Args:
        query_embedding (list): The query embedding to compare against.
        conn (psycopg2.extensions.connection): The database connection object.
        n (int, optional): The number of similar documents to fetch. Defaults to 5.
        include_metadata (bool, optional): Whether to include metadata in the results. Defaults to False.
        only_urls (bool, optional): Whether to only return URLs in the results. Defaults to False.

    Returns:
        List[Tuple]: Rows of (content, url, parent_section) when
        include_metadata is True, (url,) when only_urls is True, otherwise
        (content,) — the n nearest rows in ascending cosine distance.
    """
    embedding_array = np.array(query_embedding)
    register_vector(conn)

    # Select columns based on the flags; include_metadata takes precedence
    # over only_urls (matching the original branch order).
    if include_metadata:
        columns = "content, url, parent_section"
    elif only_urls:
        columns = "url"
    else:
        columns = "content"

    # `columns` comes from the fixed whitelist above, so the f-string is safe;
    # the embedding and LIMIT are bound parameters (no value interpolation).
    # The `with` block ensures the cursor is closed even on error.
    with conn.cursor() as cur:
        cur.execute(
            f"SELECT {columns} FROM embeddings ORDER BY embedding <=> %s LIMIT %s",
            (embedding_array, n),
        )
        return cur.fetchall()
| 304 | + |
| 305 | +def get_topn_similar_docs_elasticsearch( |
| 306 | + query_embedding: List[float], |
| 307 | + es_client: Elasticsearch, |
| 308 | + n: int = 5, |
| 309 | + include_metadata: bool = False, |
| 310 | + only_urls: bool = False |
| 311 | + ) -> List[Tuple]: |
| 312 | + """Fetches the top n most similar documents to the given query embedding from the Elasticsearch index. |
| 313 | +
|
| 314 | + Args: |
| 315 | + query_embedding (list): The query embedding to compare against. |
| 316 | + es_client (Elasticsearch): The Elasticsearch client. |
| 317 | + n (int, optional): The number of similar documents to fetch. Defaults to 5. |
| 318 | + include_metadata (bool, optional): Whether to include metadata in the results. Defaults to False. |
| 319 | + only_urls (bool, optional): Whether to only return URLs in the results. Defaults to False. |
285 | 320 | """ |
286 | 321 | index_name = "zenml_docs" |
287 | 322 |
|
@@ -329,6 +364,35 @@ def get_topn_similar_docs( |
329 | 364 |
|
330 | 365 | return results |
331 | 366 |
|
def get_topn_similar_docs(
    query_embedding: List[float],
    conn: "psycopg2.extensions.connection" = None,
    es_client: "Elasticsearch" = None,
    n: int = 5,
    include_metadata: bool = False,
    only_urls: bool = False,
) -> List[Tuple]:
    """Fetches the top n most similar documents to the given query embedding from the database.

    Dispatches to the pgvector backend when `conn` is provided, otherwise to
    the Elasticsearch backend. If both are given, `conn` takes precedence.

    Args:
        query_embedding (list): The query embedding to compare against.
        conn (psycopg2.extensions.connection, optional): PostgreSQL connection
            for the pgvector backend. Defaults to None.
        es_client (Elasticsearch, optional): Elasticsearch client for the
            Elasticsearch backend. Defaults to None.
        n (int, optional): The number of similar documents to fetch. Defaults to 5.
        include_metadata (bool, optional): Whether to include metadata in the
            results. Defaults to False.
        only_urls (bool, optional): Whether to only return URLs in the results.
            Defaults to False.

    Returns:
        list: A list of tuples containing the content and metadata (if include_metadata is True) of the top n most similar documents.

    Raises:
        ValueError: If neither `conn` nor `es_client` is provided.
    """
    if conn is not None:
        return get_topn_similar_docs_pgvector(
            query_embedding,
            conn,
            n=n,
            include_metadata=include_metadata,
            only_urls=only_urls,
        )
    if es_client is not None:
        return get_topn_similar_docs_elasticsearch(
            query_embedding,
            es_client,
            n=n,
            include_metadata=include_metadata,
            only_urls=only_urls,
        )
    raise ValueError("Either conn or es_client must be provided")
332 | 396 |
|
333 | 397 | def get_completion_from_messages( |
334 | 398 | messages, model=OPENAI_MODEL, temperature=0.4, max_tokens=1000 |
@@ -367,6 +431,18 @@ def get_embeddings(text): |
367 | 431 | model = SentenceTransformer(EMBEDDINGS_MODEL) |
368 | 432 | return model.encode(text) |
369 | 433 |
|
def find_vectorstore_name(model_version: str = "v0.68.1-dev") -> str:
    """Finds the name of the vector store used for the given embeddings model.

    Reads the `vector_store` entry from the run metadata of the chatbot
    model version registered in ZenML.

    Args:
        model_version (str, optional): The ZenML model version to inspect.
            Defaults to "v0.68.1-dev" (the previously hard-coded version),
            so existing callers are unaffected.

    Returns:
        str: The name of the vector store (e.g. "pgvector").
    """
    # Local import keeps ZenML out of module import time.
    from zenml.client import Client

    client = Client()
    model = client.get_model_version(
        ZENML_CHATBOT_MODEL,
        model_version_name_or_number_or_id=model_version,
    )
    return model.run_metadata["vector_store"].value["name"]
370 | 446 |
|
371 | 447 | def rerank_documents( |
372 | 448 | query: str, documents: List[Tuple], reranker_model: str = "flashrank" |
@@ -420,11 +496,20 @@ def process_input_with_retrieval( |
420 | 496 | str: The processed output. |
421 | 497 | """ |
422 | 498 | delimiter = "```" |
| 499 | + es_client = None |
| 500 | + conn = None |
| 501 | + |
| 502 | + vector_store_name = find_vectorstore_name() |
| 503 | + if vector_store_name == "pgvector": |
| 504 | + conn = get_db_conn() |
| 505 | + else: |
| 506 | + es_client = get_es_client() |
423 | 507 |
|
424 | 508 | # Step 1: Get documents related to the user input from database |
425 | 509 | related_docs = get_topn_similar_docs( |
426 | 510 | get_embeddings(input), |
427 | | - get_es_client(), |
| 511 | + conn=conn, |
| 512 | + es_client=es_client, |
428 | 513 | n=n_items_retrieved, |
429 | 514 | include_metadata=use_reranking, |
430 | 515 | ) |
|
0 commit comments