|
26 | 26 | from zenml.client import Client |
27 | 27 |
|
28 | 28 | from utils.openai_utils import get_openai_api_key |
29 | | - |
| 29 | +import pinecone |
| 30 | +from pinecone import Pinecone |
30 | 31 | # Configure logging levels for specific modules |
31 | 32 | logging.getLogger("pytorch").setLevel(logging.CRITICAL) |
32 | 33 | logging.getLogger("sentence-transformers").setLevel(logging.CRITICAL) |
|
37 | 38 | logging.getLogger().setLevel(logging.ERROR) |
38 | 39 |
|
39 | 40 | import re |
40 | | -from typing import List, Tuple |
| 41 | +from typing import List, Tuple, Optional |
41 | 42 |
|
42 | 43 | import litellm |
43 | 44 | import numpy as np |
|
49 | 50 | OPENAI_MODEL, |
50 | 51 | SECRET_NAME, |
51 | 52 | SECRET_NAME_ELASTICSEARCH, |
| 53 | + SECRET_NAME_PINECONE, |
52 | 54 | ZENML_CHATBOT_MODEL_NAME, |
53 | 55 | ZENML_CHATBOT_MODEL_VERSION, |
54 | 56 | ) |
@@ -277,6 +279,20 @@ def get_db_conn() -> connection: |
277 | 279 | raise |
278 | 280 |
|
279 | 281 |
|
def get_pinecone_client() -> pinecone.Index:
    """Get a Pinecone index client.

    Reads the Pinecone API key and (optionally) the index name from the
    ZenML secret store, then connects to that index.

    Returns:
        pinecone.Index: A Pinecone index client.
    """
    client = Client()
    # Fetch the secret once and reuse it — the original code called
    # get_secret() twice, doing two round-trips to the secret store.
    secret_values = client.get_secret(SECRET_NAME_PINECONE).secret_values
    pinecone_api_key = secret_values["pinecone_api_key"]
    # Index name is optional in the secret; fall back to the default index.
    index_name = secret_values.get("pinecone_index", "zenml-docs")

    pc = Pinecone(api_key=pinecone_api_key)
    return pc.Index(index_name)
| 294 | + |
| 295 | + |
280 | 296 | def get_topn_similar_docs_pgvector( |
281 | 297 | query_embedding: List[float], |
282 | 298 | conn: psycopg2.extensions.connection, |
@@ -384,39 +400,89 @@ def get_topn_similar_docs_elasticsearch( |
384 | 400 | return results |
385 | 401 |
|
386 | 402 |
|
387 | | -def get_topn_similar_docs( |
def get_topn_similar_docs_pinecone(
    query_embedding: List[float],
    pinecone_index: "pinecone.Index",
    n: int = 5,
    include_metadata: bool = False,
    only_urls: bool = False,
) -> List[Tuple]:
    """Get the top N most similar documents from Pinecone.

    Args:
        query_embedding (List[float]): The query embedding vector.
        pinecone_index (pinecone.Index): The Pinecone index client.
        n (int, optional): Number of similar documents to return. Defaults to 5.
        include_metadata (bool, optional): Whether to prefix each document with
            its filename and parent section. Defaults to False.
        only_urls (bool, optional): Whether to return only URLs instead of
            document content. Defaults to False.

    Returns:
        List[Tuple]: List of tuples containing document content (or URL) and
            similarity scores, most similar first.
    """
    # Always request metadata from Pinecone: the document text itself lives
    # in the metadata payload, so it is needed even when include_metadata
    # is False.
    results = pinecone_index.query(
        vector=query_embedding,
        top_k=n,
        include_metadata=True,
    )

    # Process results
    similar_docs = []
    for match in results.matches:
        # Vectors upserted without metadata come back with metadata=None;
        # skip those instead of raising TypeError/KeyError mid-retrieval.
        metadata = match.metadata or {}
        if not metadata:
            continue
        score = match.score

        if only_urls:
            similar_docs.append((metadata.get("url", ""), score))
        else:
            content = metadata.get("page_content", "")
            if include_metadata:
                content = f"{metadata.get('filename', '')} - {metadata.get('parent_section', '')}: {content}"
            similar_docs.append((content, score))

    return similar_docs
| 444 | + |
| 445 | + |
def get_topn_similar_docs(
    query_embedding: List[float],
    conn: Optional[psycopg2.extensions.connection] = None,
    es_client: Optional[Elasticsearch] = None,
    pinecone_index: Optional[pinecone.Index] = None,
    n: int = 5,
    include_metadata: bool = False,
    only_urls: bool = False,
) -> List[Tuple]:
    """Fetch the top N most similar documents from whichever store is given.

    Exactly one backend client should be supplied; they are checked in the
    order Elasticsearch, pgvector, Pinecone, and the first non-None one wins.

    Args:
        query_embedding (List[float]): The query embedding vector.
        conn (Optional[psycopg2.extensions.connection], optional): PostgreSQL connection. Defaults to None.
        es_client (Optional[Elasticsearch], optional): Elasticsearch client. Defaults to None.
        pinecone_index (Optional[pinecone.Index], optional): Pinecone index client. Defaults to None.
        n (int, optional): Number of similar documents to return. Defaults to 5.
        include_metadata (bool, optional): Whether to include metadata in results. Defaults to False.
        only_urls (bool, optional): Whether to return only URLs. Defaults to False.

    Returns:
        List[Tuple]: List of tuples containing document content and similarity scores.

    Raises:
        ValueError: If no valid vector store client is provided.
    """
    # Guard-clause dispatch: delegate to the backend-specific helper for
    # the first client that was actually passed in.
    if es_client is not None:
        return get_topn_similar_docs_elasticsearch(
            query_embedding, es_client, n, include_metadata, only_urls
        )
    if conn is not None:
        return get_topn_similar_docs_pgvector(
            query_embedding, conn, n, include_metadata, only_urls
        )
    if pinecone_index is not None:
        return get_topn_similar_docs_pinecone(
            query_embedding, pinecone_index, n, include_metadata, only_urls
        )
    raise ValueError("No valid vector store client provided")
420 | 486 |
|
421 | 487 |
|
422 | 488 | def get_completion_from_messages( |
@@ -525,32 +591,46 @@ def process_input_with_retrieval( |
525 | 591 | str: The processed output. |
526 | 592 | """ |
527 | 593 | delimiter = "```" |
528 | | - es_client = None |
529 | | - conn = None |
| 594 | + # Get embeddings for the query |
| 595 | + query_embedding = get_embeddings(input) |
530 | 596 |
|
531 | | - vector_store_name = find_vectorstore_name() |
532 | | - if vector_store_name == "pgvector": |
533 | | - conn = get_db_conn() |
534 | | - else: |
| 597 | + # Get similar documents based on the vector store being used |
| 598 | + vector_store = find_vectorstore_name() |
| 599 | + if vector_store == "elasticsearch": |
535 | 600 | es_client = get_es_client() |
| 601 | + similar_docs = get_topn_similar_docs( |
| 602 | + query_embedding=query_embedding, |
| 603 | + es_client=es_client, |
| 604 | + n=n_items_retrieved, |
| 605 | + include_metadata=True, |
| 606 | + ) |
| 607 | + elif vector_store == "pinecone": |
| 608 | + pinecone_index = get_pinecone_client() |
| 609 | + similar_docs = get_topn_similar_docs( |
| 610 | + query_embedding=query_embedding, |
| 611 | + pinecone_index=pinecone_index, |
| 612 | + n=n_items_retrieved, |
| 613 | + include_metadata=True, |
| 614 | + ) |
| 615 | + else: # pgvector |
| 616 | + conn = get_db_conn() |
| 617 | + similar_docs = get_topn_similar_docs( |
| 618 | + query_embedding=query_embedding, |
| 619 | + conn=conn, |
| 620 | + n=n_items_retrieved, |
| 621 | + include_metadata=True, |
| 622 | + ) |
| 623 | + conn.close() |
536 | 624 |
|
537 | | - # Step 1: Get documents related to the user input from database |
538 | | - related_docs = get_topn_similar_docs( |
539 | | - get_embeddings(input), |
540 | | - conn=conn, |
541 | | - es_client=es_client, |
542 | | - n=n_items_retrieved, |
543 | | - include_metadata=use_reranking, |
544 | | - ) |
545 | | - |
| 625 | + # Rerank documents if enabled |
546 | 626 | if use_reranking: |
547 | 627 | # Rerank the documents based on the input |
548 | 628 | # and take the top 5 only |
549 | 629 | context_content = [ |
550 | | - doc[0] for doc in rerank_documents(input, related_docs)[:5] |
| 630 | + doc[0] for doc in rerank_documents(input, similar_docs)[:5] |
551 | 631 | ] |
552 | 632 | else: |
553 | | - context_content = [doc[0] for doc in related_docs[:5]] |
| 633 | + context_content = [doc[0] for doc in similar_docs[:5]] |
554 | 634 |
|
555 | 635 | # Step 2: Get completion from OpenAI API |
556 | 636 | # Set system message to help set appropriate tone and context for model |
|
0 commit comments