diff --git a/llm-complete-guide/README.md b/llm-complete-guide/README.md index 75f7586e..b0593040 100644 --- a/llm-complete-guide/README.md +++ b/llm-complete-guide/README.md @@ -100,7 +100,7 @@ use for the LLM. When you're ready to make the query, run the following command: ```shell -python run.py query "how do I use a custom materializer inside my own zenml steps? i.e. how do I set it? inside the @step decorator?" --model=gpt4 +python run.py query --query-text "how do I use a custom materializer inside my own zenml steps? i.e. how do I set it? inside the @step decorator?" --model=gpt4 ``` Alternative options for LLMs to use include: @@ -147,13 +147,7 @@ export ZENML_HF_SPACE_NAME= # optional, defaults to "llm-com To deploy the RAG pipeline, you can use the following command: ```shell -python run.py --deploy -``` - -Alternatively, you can run the basic RAG pipeline *and* deploy it in one go: - -```shell -python run.py --rag --deploy +python run.py deploy ``` This will open a Hugging Face space in your browser where you can interact with diff --git a/llm-complete-guide/configs/dev/rag.yaml b/llm-complete-guide/configs/dev/rag.yaml index 9b37781b..3379a686 100644 --- a/llm-complete-guide/configs/dev/rag.yaml +++ b/llm-complete-guide/configs/dev/rag.yaml @@ -1,6 +1,5 @@ enable_cache: False -# environment configuration settings: docker: requirements: diff --git a/llm-complete-guide/constants.py b/llm-complete-guide/constants.py index 7b2767c5..358df3f4 100644 --- a/llm-complete-guide/constants.py +++ b/llm-complete-guide/constants.py @@ -17,7 +17,7 @@ import os # Vector Store constants -CHUNK_SIZE = 2000 +CHUNK_SIZE = 1000 CHUNK_OVERLAP = 50 EMBEDDING_DIMENSIONALITY = ( 384 # Update this to match the dimensionality of the new model @@ -25,6 +25,8 @@ # ZenML constants ZENML_CHATBOT_MODEL = "zenml-docs-qa-chatbot" +ZENML_CHATBOT_MODEL_NAME = "zenml-docs-qa-chatbot" +ZENML_CHATBOT_MODEL_VERSION = "0.71.0-dev" # Scraping constants RATE_LIMIT = 5 # Maximum number of requests per second @@ -35,8 +37,8 @@ MODEL_NAME_MAP = { "gpt4": "gpt-4", "gpt35": "gpt-3.5-turbo", - "claude3": "claude-3-opus-20240229", - "claudehaiku": "claude-3-haiku-20240307", + "claude3": "claude-3-5-sonnet-latest", + "claudehaiku": "claude-3-5-haiku-latest", } # CHUNKING_METHOD = "split-by-document" diff --git a/llm-complete-guide/deployment_hf.py b/llm-complete-guide/deployment_hf.py index 6724fc0f..d19f10b4 100644 --- a/llm-complete-guide/deployment_hf.py +++ b/llm-complete-guide/deployment_hf.py @@ -1,13 +1,44 @@ +import logging + import gradio as gr +from constants import SECRET_NAME from utils.llm_utils import process_input_with_retrieval +from zenml.client import Client +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -def predict(message, history): - return process_input_with_retrieval( - input=message, - n_items_retrieved=20, - use_reranking=True, +# Initialize ZenML client and verify secret access +try: + client = Client() + secret = client.get_secret(SECRET_NAME) + logger.info( + f"Successfully initialized ZenML client and found secret {SECRET_NAME}" ) +except Exception as e: + logger.error(f"Failed to initialize ZenML client or access secret: {e}") + raise RuntimeError(f"Application startup failed: {e}") + + +def predict(message, history): + try: + return process_input_with_retrieval( + input=message, + n_items_retrieved=20, + use_reranking=True, + ) + except Exception as e: + logger.error(f"Error processing message: {e}") + return f"Sorry, I encountered an error: {str(e)}" + +# Launch the Gradio interface +interface = gr.ChatInterface( + predict, + title="ZenML Documentation Assistant", + description="Ask me anything about ZenML!", +) -gr.ChatInterface(predict, type="messages").launch() +if __name__ == "__main__": + interface.launch(server_name="0.0.0.0", share=False) diff --git a/llm-complete-guide/gh_action_rag.py b/llm-complete-guide/gh_action_rag.py index ee8ac86d..e21e9980 100644 --- a/llm-complete-guide/gh_action_rag.py +++ b/llm-complete-guide/gh_action_rag.py @@ -21,12 +21,10 @@ import click import yaml -from zenml.enums import PluginSubType - from pipelines.llm_index_and_evaluate import llm_index_and_evaluate -from zenml.client import Client from zenml import Model -from zenml.exceptions import ZenKeyError +from zenml.client import Client +from zenml.enums import PluginSubType @click.command( @@ -89,7 +87,7 @@ def main( zenml_model_name: Optional[str] = "zenml-docs-qa-rag", zenml_model_version: Optional[str] = None, ): - """ + """ Executes the pipeline to train a basic RAG model. Args: @@ -108,14 +106,14 @@ def main( config = yaml.safe_load(file) # Read the model version from a file in the root of the repo - # called "ZENML_VERSION.txt". + # called "ZENML_VERSION.txt". if zenml_model_version == "staging": postfix = "-rc0" elif zenml_model_version == "production": postfix = "" else: postfix = "-dev" - + if Path("ZENML_VERSION.txt").exists(): with open("ZENML_VERSION.txt", "r") as file: zenml_model_version = file.read().strip() @@ -177,7 +175,7 @@ def main( service_account_id=service_account_id, auth_window=0, flavor="builtin", - action_type=PluginSubType.PIPELINE_RUN + action_type=PluginSubType.PIPELINE_RUN, ).id client.create_trigger( name="Production Trigger LLM-Complete", diff --git a/llm-complete-guide/pipelines/__init__.py b/llm-complete-guide/pipelines/__init__.py index ae127fa3..3e9f4d62 100644 --- a/llm-complete-guide/pipelines/__init__.py +++ b/llm-complete-guide/pipelines/__init__.py @@ -19,5 +19,5 @@ from pipelines.generate_chunk_questions import generate_chunk_questions from pipelines.llm_basic_rag import llm_basic_rag from pipelines.llm_eval import llm_eval +from pipelines.llm_index_and_evaluate import llm_index_and_evaluate from pipelines.rag_deployment import rag_deployment -from pipelines.llm_index_and_evaluate import llm_index_and_evaluate \ No newline at end of file diff --git a/llm-complete-guide/pipelines/finetune_embeddings.py b/llm-complete-guide/pipelines/finetune_embeddings.py index e53ae3f1..19b8b08c 100644 --- a/llm-complete-guide/pipelines/finetune_embeddings.py +++ b/llm-complete-guide/pipelines/finetune_embeddings.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. -from constants import EMBEDDINGS_MODEL_NAME_ZENML from steps.finetune_embeddings import ( evaluate_base_model, evaluate_finetuned_model, diff --git a/llm-complete-guide/pipelines/llm_basic_rag.py b/llm-complete-guide/pipelines/llm_basic_rag.py index 82a97b21..895c4df3 100644 --- a/llm-complete-guide/pipelines/llm_basic_rag.py +++ b/llm-complete-guide/pipelines/llm_basic_rag.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from litellm import config_path from steps.populate_index import ( generate_embeddings, diff --git a/llm-complete-guide/pipelines/llm_index_and_evaluate.py b/llm-complete-guide/pipelines/llm_index_and_evaluate.py index 16423867..b82c84a3 100644 --- a/llm-complete-guide/pipelines/llm_index_and_evaluate.py +++ b/llm-complete-guide/pipelines/llm_index_and_evaluate.py @@ -15,9 +15,10 @@ # limitations under the License. # -from pipelines import llm_basic_rag, llm_eval from zenml import pipeline +from pipelines import llm_basic_rag, llm_eval + @pipeline def llm_index_and_evaluate() -> None: diff --git a/llm-complete-guide/requirements.txt b/llm-complete-guide/requirements.txt index dc578d9d..9af8f66b 100644 --- a/llm-complete-guide/requirements.txt +++ b/llm-complete-guide/requirements.txt @@ -1,4 +1,4 @@ -zenml[server] +zenml[server]>=0.73.0 ratelimit pgvector psycopg2-binary @@ -21,6 +21,7 @@ torch gradio huggingface-hub elasticsearch +tenacity # optional requirements for S3 artifact store # s3fs>2022.3.0 diff --git a/llm-complete-guide/run.py b/llm-complete-guide/run.py index a2ba1f94..360f8af2 100644 --- a/llm-complete-guide/run.py +++ b/llm-complete-guide/run.py @@ -47,12 +47,12 @@ generate_synthetic_data, llm_basic_rag, llm_eval, - rag_deployment, llm_index_and_evaluate, + rag_deployment, ) from structures import Document -from zenml.materializers.materializer_registry import materializer_registry from zenml import Model +from zenml.materializers.materializer_registry import materializer_registry logger = get_logger(__name__) @@ -136,6 +136,12 @@ default=None, help="Path to config", ) +@click.option( + "--query-text", + "query_text", + default=None, + help="Query text", +) def main( pipeline: str, query_text: Optional[str] = None, @@ -169,9 +175,9 @@ def main( } }, } - + # Read the model version from a file in the root of the repo - # called "ZENML_VERSION.txt". + # called "ZENML_VERSION.txt". if zenml_model_version == "staging": postfix = "-rc0" elif zenml_model_version == "production": @@ -264,7 +270,9 @@ def main( elif pipeline == "embeddings": finetune_embeddings.with_options( - model=zenml_model, config_path=config_path, **embeddings_finetune_args + model=zenml_model, + config_path=config_path, + **embeddings_finetune_args, )() elif pipeline == "chunks": diff --git a/llm-complete-guide/steps/eval_e2e.py b/llm-complete-guide/steps/eval_e2e.py index 8797d66b..319b8cbb 100644 --- a/llm-complete-guide/steps/eval_e2e.py +++ b/llm-complete-guide/steps/eval_e2e.py @@ -16,8 +16,17 @@ import json import logging +import warnings from typing import Annotated, Callable, Tuple +# Suppress the specific FutureWarning about clean_up_tokenization_spaces +warnings.filterwarnings( + "ignore", + message=".*clean_up_tokenization_spaces.*", + category=FutureWarning, + module="transformers.tokenization_utils_base", +) + from datasets import load_dataset from litellm import completion from pydantic import BaseModel, conint @@ -315,13 +324,11 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float: @step -def e2e_evaluation() -> ( - Tuple[ - Annotated[float, "failure_rate_bad_answers"], - Annotated[float, "failure_rate_bad_immediate_responses"], - Annotated[float, "failure_rate_good_responses"], - ] -): +def e2e_evaluation() -> Tuple[ + Annotated[float, "failure_rate_bad_answers"], + Annotated[float, "failure_rate_bad_immediate_responses"], + Annotated[float, "failure_rate_good_responses"], +]: """Executes the end-to-end evaluation step.""" logging.info("Testing bad answers...") failure_rate_bad_answers = run_simple_tests( @@ -352,14 +359,12 @@ def e2e_evaluation() -> ( @step -def e2e_evaluation_llm_judged() -> ( - Tuple[ - Annotated[float, "average_toxicity_score"], - Annotated[float, "average_faithfulness_score"], - Annotated[float, "average_helpfulness_score"], - Annotated[float, "average_relevance_score"], - ] -): +def e2e_evaluation_llm_judged() -> Tuple[ + Annotated[float, "average_toxicity_score"], + Annotated[float, "average_faithfulness_score"], + Annotated[float, "average_helpfulness_score"], + Annotated[float, "average_relevance_score"], +]: """Executes the end-to-end evaluation step. Returns: diff --git a/llm-complete-guide/steps/eval_retrieval.py b/llm-complete-guide/steps/eval_retrieval.py index 2b555b85..3c9c83d9 100644 --- a/llm-complete-guide/steps/eval_retrieval.py +++ b/llm-complete-guide/steps/eval_retrieval.py @@ -14,10 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging -from typing import Annotated, List, Tuple +from functools import partial +from multiprocessing import Pool, cpu_count +from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple from datasets import load_dataset +from tenacity import retry, stop_after_attempt, wait_exponential from utils.llm_utils import ( find_vectorstore_name, get_db_conn, @@ -27,14 +31,20 @@ rerank_documents, ) from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) # Adjust logging settings as before logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + level=logging.DEBUG, # Change to DEBUG level + format="%(asctime)s - %(levelname)s - %(message)s", ) - +# Only set external loggers to WARNING logging.getLogger("sentence_transformers").setLevel(logging.WARNING) +logging.getLogger("urllib3").setLevel(logging.WARNING) +logging.getLogger("elasticsearch").setLevel(logging.WARNING) question_doc_pairs = [ { @@ -90,11 +100,11 @@ def query_similar_docs( num_docs = 20 if use_reranking else returned_sample_size # get (content, url) tuples for the top n similar documents top_similar_docs = get_topn_similar_docs( - embedded_question, - conn=conn, - es_client=es_client, - n=num_docs, - include_metadata=True + embedded_question, + conn=conn, + es_client=es_client, + n=num_docs, + include_metadata=True, ) if use_reranking: @@ -108,6 +118,79 @@ def query_similar_docs( return (question, url_ending, urls) +def process_single_pair( + pair: Dict, use_reranking: bool = False +) -> Tuple[bool, str, str, List[str]]: + """Process a single question-document pair. + + Args: + pair: Dictionary containing question and URL ending + use_reranking: Whether to use reranking to improve retrieval + + Returns: + Tuple containing (is_failure, question, url_ending, retrieved_urls) + """ + question, url_ending, urls = query_similar_docs( + pair["question"], pair["url_ending"], use_reranking + ) + is_failure = all(url_ending not in url for url in urls) + return is_failure, question, url_ending, urls + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + reraise=True, +) +def process_single_pair_with_retry( + pair: Dict, use_reranking: bool = False +) -> Tuple[bool, str, str, List[str]]: + """Process a single question-document pair with retry logic. + + Args: + pair: Dictionary containing question and URL ending + use_reranking: Whether to use reranking to improve retrieval + + Returns: + Tuple containing (is_failure, question, url_ending, retrieved_urls) + """ + try: + return process_single_pair(pair, use_reranking) + except Exception as e: + logging.warning( + f"Error processing pair {pair['question']}: {str(e)}. Retrying..." + ) + raise + + +def process_with_progress( + items: List, worker_fn: Callable, n_processes: int, desc: str +) -> List: + """Process items in parallel with progress reporting. + + Args: + items: List of items to process + worker_fn: Worker function to apply to each item + n_processes: Number of processes to use + desc: Description for the progress bar + + Returns: + List of results + """ + logger.info( + f"{desc} - Starting parallel processing with {n_processes} workers" + ) + + results = [] + with Pool(processes=n_processes) as pool: + for i, result in enumerate(pool.imap(worker_fn, items), 1): + results.append(result) + logger.info(f"Completed {i}/{len(items)} tests") + + logger.info(f"{desc} - Completed processing {len(results)} items") + return results + + def test_retrieved_docs_retrieve_best_url( question_doc_pairs: list, use_reranking: bool = False ) -> float: @@ -122,21 +205,46 @@ def test_retrieved_docs_retrieve_best_url( The failure rate of the retrieval test. """ total_tests = len(question_doc_pairs) - failures = 0 + logger.info(f"Starting retrieval test with {total_tests} questions...") - for pair in question_doc_pairs: - question, url_ending, urls = query_similar_docs( - pair["question"], pair["url_ending"], use_reranking + n_processes = max(1, cpu_count() // 2) + worker = partial( + process_single_pair_with_retry, use_reranking=use_reranking + ) + + try: + results = process_with_progress( + question_doc_pairs, + worker, + n_processes, + "Testing document retrieval", + ) + + failures = 0 + logger.info("\nTest Results:") + for is_failure, question, url_ending, urls in results: + if is_failure: + failures += 1 + logger.error( + f"Failed test for question: '{question}'\n" + f"Expected URL ending: {url_ending}\n" + f"Got URLs: {urls}" + ) + else: + logger.info(f"Passed test for question: '{question}'") + + failure_rate = (failures / total_tests) * 100 + logger.info( + f"\nTest Summary:\n" + f"Total tests: {total_tests}\n" + f"Failures: {failures}\n" + f"Failure rate: {failure_rate}%" ) - if all(url_ending not in url for url in urls): - logging.error( - f"Failed for question: {question}. Expected URL ending: {url_ending}. Got: {urls}" - ) - failures += 1 + return round(failure_rate, 2) - logging.info(f"Total tests: {total_tests}. Failures: {failures}") - failure_rate = (failures / total_tests) * 100 - return round(failure_rate, 2) + except Exception as e: + logger.error(f"Error during parallel processing: {str(e)}") + raise def perform_small_retrieval_evaluation(use_reranking: bool) -> float: @@ -158,9 +266,9 @@ def perform_small_retrieval_evaluation(use_reranking: bool) -> float: @step -def retrieval_evaluation_small() -> ( - Annotated[float, "small_failure_rate_retrieval"] -): +def retrieval_evaluation_small() -> Annotated[ + float, "small_failure_rate_retrieval" +]: """Executes the retrieval evaluation step without reranking. Returns: @@ -170,9 +278,9 @@ def retrieval_evaluation_small() -> ( @step -def retrieval_evaluation_small_with_reranking() -> ( - Annotated[float, "small_failure_rate_retrieval_reranking"] -): +def retrieval_evaluation_small_with_reranking() -> Annotated[ + float, "small_failure_rate_retrieval_reranking" +]: """Executes the retrieval evaluation step with reranking. Returns: @@ -181,6 +289,55 @@ def retrieval_evaluation_small_with_reranking() -> ( return perform_small_retrieval_evaluation(use_reranking=True) +def process_single_dataset_item( + item: Dict, use_reranking: bool = False +) -> Tuple[bool, str, str, List[str]]: + """Process a single dataset item. + + Args: + item: Dictionary containing the dataset item with generated questions and filename + use_reranking: Whether to use reranking to improve retrieval + + Returns: + Tuple containing (is_failure, question, url_ending, retrieved_urls) + """ + generated_questions = item["generated_questions"] + question = generated_questions[0] # Assuming only one question per item + url_ending = item["filename"].split("/")[ + -1 + ] # Extract the URL ending from the filename + + _, _, urls = query_similar_docs(question, url_ending, use_reranking) + is_failure = all(url_ending not in url for url in urls) + return is_failure, question, url_ending, urls + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + reraise=True, +) +def process_single_dataset_item_with_retry( + item: Dict, use_reranking: bool = False +) -> Tuple[bool, str, str, List[str]]: + """Process a single dataset item with retry logic. + + Args: + item: Dictionary containing the dataset item + use_reranking: Whether to use reranking to improve retrieval + + Returns: + Tuple containing (is_failure, question, url_ending, retrieved_urls) + """ + try: + return process_single_dataset_item(item, use_reranking) + except Exception as e: + logging.warning( + f"Error processing dataset item: {str(e)}. Retrying..." + ) + raise + + def perform_retrieval_evaluation( sample_size: int, use_reranking: bool ) -> float: @@ -197,28 +354,41 @@ def perform_retrieval_evaluation( sampled_dataset = dataset.shuffle(seed=42).select(range(sample_size)) total_tests = len(sampled_dataset) - failures = 0 - - for item in sampled_dataset: - generated_questions = item["generated_questions"] - question = generated_questions[ - 0 - ] # Assuming only one question per item - url_ending = item["filename"].split("/")[ - -1 - ] # Extract the URL ending from the filename + n_processes = max(1, cpu_count() // 2) + worker = partial( + process_single_dataset_item_with_retry, use_reranking=use_reranking + ) - _, _, urls = query_similar_docs(question, url_ending, use_reranking) + try: + results = process_with_progress( + sampled_dataset, worker, n_processes, "Evaluating retrieval" + ) - if all(url_ending not in url for url in urls): - logging.error( - f"Failed for question: {question}. Expected URL containing: {url_ending}. Got: {urls}" - ) - failures += 1 + failures = 0 + logger.info("\nTest Results:") + for is_failure, question, url_ending, urls in results: + if is_failure: + failures += 1 + logger.error( + f"Failed test for question: '{question}'\n" + f"Expected URL containing: {url_ending}\n" + f"Got URLs: {urls}" + ) + else: + logger.info(f"Passed test for question: '{question}'") + + failure_rate = (failures / total_tests) * 100 + logger.info( + f"\nTest Summary:\n" + f"Total tests: {total_tests}\n" + f"Failures: {failures}\n" + f"Failure rate: {failure_rate}%" + ) + return round(failure_rate, 2) - logging.info(f"Total tests: {total_tests}. Failures: {failures}") - failure_rate = (failures / total_tests) * 100 - return round(failure_rate, 2) + except Exception as e: + logger.error(f"Error during parallel processing: {str(e)}") + raise @step @@ -257,3 +427,230 @@ def retrieval_evaluation_full_with_reranking( ) logging.info(f"Retrieval failure rate with reranking: {failure_rate}%") return failure_rate + + +def process_single_test( + item: Any, + test_function: Callable, +) -> Tuple[bool, str, str, str]: + """Process a single test item. + + Args: + item: The test item to process + test_function: The test function to run + + Returns: + Tuple containing (is_failure, question, keyword, response) + """ + test_result = test_function(item) + return ( + not test_result.success, + test_result.question, + test_result.keyword, + test_result.response, + ) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + reraise=True, +) +def process_single_test_with_retry( + item: Any, + test_function: Callable, +) -> Tuple[bool, str, str, str]: + """Process a single test item with retry logic. + + Args: + item: The test item to process + test_function: The test function to run + + Returns: + Tuple containing (is_failure, question, keyword, response) + """ + try: + return process_single_test(item, test_function) + except Exception as e: + logging.warning(f"Error processing test item: {str(e)}. Retrying...") + raise + + +def run_simple_tests(test_data: list, test_function: Callable) -> float: + """ + Run tests for bad answers in parallel with progress reporting and error handling. + + Args: + test_data (list): The test data. + test_function (function): The test function to run. + + Returns: + float: The failure rate. + """ + total_tests = len(test_data) + n_processes = max(1, cpu_count() // 2) + worker = partial( + process_single_test_with_retry, test_function=test_function + ) + + try: + results = process_with_progress( + test_data, worker, n_processes, "Running tests" + ) + + failures = 0 + logger.info("\nTest Results:") + for is_failure, question, keyword, response in results: + if is_failure: + failures += 1 + logger.error( + f"Failed test for question: '{question}'\n" + f"Found word: '{keyword}'\n" + f"Response: '{response}'" + ) + else: + logger.info(f"Passed test for question: '{question}'") + + failure_rate = (failures / total_tests) * 100 + logger.info( + f"\nTest Summary:\n" + f"Total tests: {total_tests}\n" + f"Failures: {failures}\n" + f"Failure rate: {failure_rate}%" + ) + return round(failure_rate, 2) + + except Exception as e: + logger.error(f"Error during parallel processing: {str(e)}") + raise + + +def process_single_llm_test( + item: Dict, + test_function: Callable, +) -> Tuple[float, float, float, float]: + """Process a single LLM test item. + + Args: + item: Dictionary containing the dataset item + test_function: The test function to run + + Returns: + Tuple containing (toxicity, faithfulness, helpfulness, relevance) scores + """ + # Assuming only one question per item + question = item["generated_questions"][0] + context = item["page_content"] + + try: + result = test_function(question, context) + return ( + result.toxicity, + result.faithfulness, + result.helpfulness, + result.relevance, + ) + except json.JSONDecodeError as e: + logger.error(f"Failed for question: {question}. Error: {e}") + # Return None to indicate this test should be skipped + return None + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + reraise=True, +) +def process_single_llm_test_with_retry( + item: Dict, + test_function: Callable, +) -> Optional[Tuple[float, float, float, float]]: + """Process a single LLM test item with retry logic. + + Args: + item: Dictionary containing the dataset item + test_function: The test function to run + + Returns: + Optional tuple containing (toxicity, faithfulness, helpfulness, relevance) scores + Returns None if the test should be skipped + """ + try: + return process_single_llm_test(item, test_function) + except Exception as e: + logger.warning(f"Error processing LLM test: {str(e)}. Retrying...") + raise + + +def run_llm_judged_tests( + test_function: Callable, + sample_size: int = 10, +) -> Tuple[ + Annotated[float, "average_toxicity_score"], + Annotated[float, "average_faithfulness_score"], + Annotated[float, "average_helpfulness_score"], + Annotated[float, "average_relevance_score"], +]: + """E2E tests judged by an LLM. + + Args: + test_data (list): The test data. + test_function (function): The test function to run. + sample_size (int): The sample size to run the tests on. + + Returns: + Tuple: The average toxicity, faithfulness, helpfulness, and relevance scores. + """ + # Load the dataset from the Hugging Face Hub + dataset = load_dataset("zenml/rag_qa_embedding_questions", split="train") + + # Shuffle the dataset and select a random sample + sampled_dataset = dataset.shuffle(seed=42).select(range(sample_size)) + + n_processes = max(1, cpu_count() // 2) + worker = partial( + process_single_llm_test_with_retry, test_function=test_function + ) + + try: + results = process_with_progress( + sampled_dataset, worker, n_processes, "Running LLM judged tests" + ) + + # Filter out None results (failed tests) + valid_results = [r for r in results if r is not None] + total_tests = len(valid_results) + + if total_tests == 0: + logger.error("All tests failed!") + return 0.0, 0.0, 0.0, 0.0 + + # Calculate totals + total_toxicity = sum(r[0] for r in valid_results) + total_faithfulness = sum(r[1] for r in valid_results) + total_helpfulness = sum(r[2] for r in valid_results) + total_relevance = sum(r[3] for r in valid_results) + + # Calculate averages + average_toxicity_score = total_toxicity / total_tests + average_faithfulness_score = total_faithfulness / total_tests + average_helpfulness_score = total_helpfulness / total_tests + average_relevance_score = total_relevance / total_tests + + logger.info("\nTest Results Summary:") + logger.info(f"Total valid tests: {total_tests}") + logger.info(f"Average toxicity: {average_toxicity_score:.3f}") + logger.info(f"Average faithfulness: {average_faithfulness_score:.3f}") + logger.info(f"Average helpfulness: {average_helpfulness_score:.3f}") + logger.info(f"Average relevance: {average_relevance_score:.3f}") + + return ( + round(average_toxicity_score, 3), + round(average_faithfulness_score, 3), + round(average_helpfulness_score, 3), + round(average_relevance_score, 3), + ) + + except Exception as e: + logger.error(f"Error during parallel processing: {str(e)}") + raise diff --git a/llm-complete-guide/steps/eval_visualisation.py b/llm-complete-guide/steps/eval_visualisation.py index badd62c1..1a582490 100644 --- a/llm-complete-guide/steps/eval_visualisation.py +++ b/llm-complete-guide/steps/eval_visualisation.py @@ -18,7 +18,7 @@ import matplotlib.pyplot as plt import numpy as np from PIL import Image -from zenml import ArtifactConfig, get_step_context, step +from zenml import ArtifactConfig, get_step_context, log_metadata, step def create_image( @@ -124,7 +124,7 @@ def visualize_evaluation_results( Annotated[Image.Image, ArtifactConfig(name="generation_eval_full")], ]: """ - Visualize the evaluation results by creating three separate images. + Visualize the evaluation results by creating three separate images and logging metrics. Args: small_retrieval_eval_failure_rate (float): Small retrieval evaluation failure rate. @@ -145,6 +145,38 @@ def visualize_evaluation_results( step_context = get_step_context() pipeline_run_name = step_context.pipeline_run.name + # Log all metrics as metadata for dashboard visualization + log_metadata( + metadata={ + # Retrieval metrics + "retrieval.small_failure_rate": small_retrieval_eval_failure_rate, + "retrieval.small_failure_rate_reranking": small_retrieval_eval_failure_rate_reranking, + "retrieval.full_failure_rate": full_retrieval_eval_failure_rate, + "retrieval.full_failure_rate_reranking": full_retrieval_eval_failure_rate_reranking, + # Generation failure metrics + "generation.failure_rate_bad_answers": failure_rate_bad_answers, + "generation.failure_rate_bad_immediate": failure_rate_bad_immediate_responses, + "generation.failure_rate_good": failure_rate_good_responses, + # Quality metrics + "quality.toxicity": average_toxicity_score, + "quality.faithfulness": average_faithfulness_score, + "quality.helpfulness": average_helpfulness_score, + "quality.relevance": average_relevance_score, + # Composite scores + "composite.overall_quality": ( + average_faithfulness_score + + average_helpfulness_score + + average_relevance_score + ) + / 3, + "composite.retrieval_effectiveness": ( + (1 - small_retrieval_eval_failure_rate) + + (1 - full_retrieval_eval_failure_rate) + ) + / 2, + } + ) + normalized_scores = [ score / 20 for score in [ diff --git a/llm-complete-guide/steps/finetune_embeddings.py b/llm-complete-guide/steps/finetune_embeddings.py index 3117c473..44a2f707 100644 --- a/llm-complete-guide/steps/finetune_embeddings.py +++ b/llm-complete-guide/steps/finetune_embeddings.py @@ -49,6 +49,7 @@ from sentence_transformers.util import cos_sim from zenml import ArtifactConfig, log_model_metadata, step from zenml.client import Client +from zenml.enums import ArtifactType from zenml.utils.cuda_utils import cleanup_gpu_memory @@ -218,7 +219,7 @@ def finetune( ) -> Annotated[ SentenceTransformer, ArtifactConfig( - is_model_artifact=True, + artifact_type=ArtifactType.MODEL, name="finetuned-model", ), ]: diff --git a/llm-complete-guide/steps/hf_dataset_loader.py b/llm-complete-guide/steps/hf_dataset_loader.py index 5615ba4a..0c777757 100644 --- a/llm-complete-guide/steps/hf_dataset_loader.py +++ b/llm-complete-guide/steps/hf_dataset_loader.py @@ -23,9 +23,9 @@ @step(output_materializers=HFDatasetMaterializer) -def load_hf_dataset() -> ( - Tuple[Annotated[Dataset, "train"], Annotated[Dataset, "test"]] -): +def load_hf_dataset() -> Tuple[ + Annotated[Dataset, "train"], Annotated[Dataset, "test"] +]: train_dataset = load_dataset(DATASET_NAME_DEFAULT, split="train") test_dataset = load_dataset(DATASET_NAME_DEFAULT, split="test") return train_dataset, test_dataset diff --git a/llm-complete-guide/steps/populate_index.py b/llm-complete-guide/steps/populate_index.py index d9a9bd95..6397ebed 100644 --- a/llm-complete-guide/steps/populate_index.py +++ b/llm-complete-guide/steps/populate_index.py @@ -23,26 +23,34 @@ import json import logging import math -from typing import Annotated, Any, Dict, List, Tuple +import warnings from enum import Enum +from typing import Annotated, Any, Dict, List, Tuple + +# Suppress the specific FutureWarning about clean_up_tokenization_spaces +warnings.filterwarnings( + "ignore", + message=".*clean_up_tokenization_spaces.*", + category=FutureWarning, + module="transformers.tokenization_utils_base", +) from constants import ( CHUNK_OVERLAP, CHUNK_SIZE, EMBEDDING_DIMENSIONALITY, EMBEDDINGS_MODEL, + SECRET_NAME, SECRET_NAME_ELASTICSEARCH, - ZENML_CHATBOT_MODEL, ) from pgvector.psycopg2 import register_vector from PIL import Image, ImageDraw, ImageFont from sentence_transformers import SentenceTransformer from structures import Document from utils.llm_utils import get_db_conn, get_es_client, split_documents -from zenml import ArtifactConfig, log_artifact_metadata, step, log_model_metadata -from zenml.metadata.metadata_types import Uri +from zenml import ArtifactConfig, log_metadata, step from zenml.client import Client -from constants import SECRET_NAME +from zenml.metadata.metadata_types import Uri logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -453,11 +461,11 @@ def draw_bar_chart( """Draws a bar chart on the given image.""" # Ensure labels is a list, even if empty labels = labels or [] - + # Skip drawing if no data if not data: return - + max_value = max(data) bar_width = width // len(data) bar_spacing = 10 @@ -487,10 +495,21 @@ def draw_bar_chart( for i, label in enumerate(labels): if label is not None: # Add null check for individual labels font = ImageFont.load_default(size=10) - bbox = draw.textbbox((0, 0), str(label), font=font) # Convert to string + bbox = draw.textbbox( + (0, 0), str(label), font=font + ) # Convert to string label_width = bbox[2] - bbox[0] - label_x = x + i * (bar_width + bar_spacing) + (bar_width - label_width) // 2 - draw.text((label_x, y + height - 15), str(label), font=font, fill="black") + label_x = ( + x + + i * (bar_width + bar_spacing) + + (bar_width - label_width) // 2 + ) + draw.text( + (label_x, y + height - 15), + str(label), + font=font, + fill="black", + ) @step @@ -515,12 +534,13 @@ def preprocess_documents( Exception: If an error occurs during preprocessing. """ try: - log_artifact_metadata( - artifact_name="split_chunks", + log_metadata( metadata={ "chunk_size": CHUNK_SIZE, "chunk_overlap": CHUNK_OVERLAP, }, + artifact_name="split_chunks", + infer_artifact=True, ) document_list: List[Document] = [ @@ -536,9 +556,10 @@ def preprocess_documents( histogram_chart: Image.Image = create_histogram(stats) bar_chart: Image.Image = create_bar_chart(stats) - log_artifact_metadata( + log_metadata( artifact_name="split_chunks", metadata=stats, + infer_artifact=True, ) split_docs_json: str = json.dumps([doc.__dict__ for doc in split_docs]) @@ -566,14 +587,20 @@ def generate_embeddings( Exception: If an error occurs during the generation of embeddings. """ try: + # Initialize the model model = SentenceTransformer(EMBEDDINGS_MODEL) - log_artifact_metadata( - artifact_name="documents_with_embeddings", + # Set clean_up_tokenization_spaces to False on the underlying tokenizer to avoid the warning + if hasattr(model.tokenizer, "clean_up_tokenization_spaces"): + model.tokenizer.clean_up_tokenization_spaces = False + + log_metadata( metadata={ "embedding_type": EMBEDDINGS_MODEL, "embedding_dimensionality": EMBEDDING_DIMENSIONALITY, }, + artifact_name="documents_with_embeddings", + infer_artifact=True, ) # Parse the JSON string into a list of Document objects @@ -600,10 +627,11 @@ class IndexType(Enum): ELASTICSEARCH = "elasticsearch" POSTGRES = "postgres" + @step(enable_cache=False) def index_generator( documents: str, - index_type: IndexType = IndexType.ELASTICSEARCH, + index_type: IndexType = IndexType.POSTGRES, ) -> None: """Generates an index for the given documents. @@ -624,11 +652,12 @@ def index_generator( _index_generator_elastic(documents) else: _index_generator_postgres(documents) - + except Exception as e: logger.error(f"Error in index_generator: {e}") raise + def _index_generator_elastic(documents: str) -> None: """Generates an Elasticsearch index for the given documents.""" try: @@ -647,11 +676,11 @@ def _index_generator_elastic(documents: str) -> None: "type": "dense_vector", "dims": EMBEDDING_DIMENSIONALITY, "index": True, - "similarity": "cosine" + "similarity": "cosine", }, "filename": {"type": "text"}, "parent_section": {"type": "text"}, - "url": {"type": "text"} + "url": {"type": "text"}, } } } @@ -661,50 +690,49 @@ def _index_generator_elastic(documents: str) -> None: # Parse the JSON string into a list of Document objects document_list = [Document(**doc) for doc in json.loads(documents)] operations = [] - + for doc in document_list: content_hash = hashlib.md5( f"{doc.page_content}{doc.filename}{doc.parent_section}{doc.url}".encode() ).hexdigest() - - exists_query = { - "query": { - "term": { - "doc_id": content_hash - } - } - } - + + exists_query = {"query": {"term": {"doc_id": content_hash}}} + if not es.count(index=index_name, body=exists_query)["count"]: - operations.append({ - "index": { - "_index": index_name, - "_id": content_hash + operations.append( + {"index": {"_index": index_name, "_id": content_hash}} + ) + + operations.append( + { + "doc_id": content_hash, + "content": doc.page_content, + "token_count": doc.token_count, + "embedding": doc.embedding, + "filename": doc.filename, + "parent_section": doc.parent_section, + "url": doc.url, } - }) - - operations.append({ - "doc_id": content_hash, - "content": doc.page_content, - "token_count": doc.token_count, - "embedding": doc.embedding, - "filename": doc.filename, - "parent_section": doc.parent_section, - "url": doc.url - }) - + ) + if operations: response = es.bulk(operations=operations, timeout="10m") - - success_count = sum(1 for item in response['items'] if 'index' in item and item['index']['status'] == 201) - failed_count = len(response['items']) - success_count - + + success_count = sum( + 1 + for item in response["items"] + if "index" in item and item["index"]["status"] == 201 + ) + failed_count = len(response["items"]) - success_count + logger.info(f"Successfully indexed {success_count} documents") if failed_count > 0: logger.warning(f"Failed to index {failed_count} documents") - for item in response['items']: - if 'index' in item and item['index']['status'] != 201: - logger.warning(f"Failed to index document: {item['index']['error']}") + for item in response["items"]: + if "index" in item and item["index"]["status"] != 201: + logger.warning( + f"Failed to index document: {item['index']['error']}" + ) else: logger.info("No new documents to index") @@ -714,11 +742,12 @@ def _index_generator_elastic(documents: str) -> None: logger.error(f"Error in Elasticsearch indexing: {e}") raise + def _index_generator_postgres(documents: str) -> None: """Generates a PostgreSQL index for the given documents.""" try: conn = get_db_conn() - + with conn.cursor() as cur: # Install pgvector if not already installed cur.execute("CREATE EXTENSION IF NOT EXISTS vector") @@ -740,7 +769,7 @@ def _index_generator_postgres(documents: str) -> None: conn.commit() register_vector(conn) - + # Parse the JSON string into a list of Document objects document_list = [Document(**doc) for doc in json.loads(documents)] @@ -772,7 +801,6 @@ def _index_generator_postgres(documents: str) -> None: ) conn.commit() - cur.execute("SELECT COUNT(*) as cnt FROM embeddings;") num_records = cur.fetchone()[0] logger.info(f"Number of vector records in table: {num_records}") @@ -797,6 +825,7 @@ def _index_generator_postgres(documents: str) -> None: if conn: conn.close() + def _log_metadata(index_type: IndexType) -> None: """Log metadata about the indexing process.""" prompt = """ @@ -809,9 +838,11 @@ def _log_metadata(index_type: IndexType) -> None: """ client = Client() - + if index_type == IndexType.ELASTICSEARCH: - es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"] + es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values[ + "elasticsearch_host" + ] connection_details = { "host": es_host, "api_key": "*********", @@ -821,14 +852,20 @@ def _log_metadata(index_type: IndexType) -> None: store_name = "pgvector" connection_details = { - "user": client.get_secret(SECRET_NAME).secret_values["supabase_user"], + "user": client.get_secret(SECRET_NAME).secret_values[ + "supabase_user" + ], "password": "**********", - "host": client.get_secret(SECRET_NAME).secret_values["supabase_host"], - "port": client.get_secret(SECRET_NAME).secret_values["supabase_port"], + "host": client.get_secret(SECRET_NAME).secret_values[ + "supabase_host" + ], + "port": client.get_secret(SECRET_NAME).secret_values[ + "supabase_port" + ], "dbname": "postgres", } - log_model_metadata( + log_metadata( metadata={ "embeddings": { "model": EMBEDDINGS_MODEL, @@ -843,4 +880,5 @@ def _log_metadata(index_type: IndexType) -> None: "connection_details": connection_details, }, }, + infer_model=True, ) diff --git a/llm-complete-guide/steps/rag_deployment.py b/llm-complete-guide/steps/rag_deployment.py index 99a8c911..3a0783f1 100644 --- a/llm-complete-guide/steps/rag_deployment.py +++ b/llm-complete-guide/steps/rag_deployment.py @@ -1,26 +1,37 @@ import os import webbrowser +from constants import SECRET_NAME from huggingface_hub import HfApi - from utils.hf_utils import get_hf_token from utils.llm_utils import process_input_with_retrieval from zenml import step from zenml.client import Client from zenml.integrations.registry import integration_registry -secret = Client().get_secret("llm-complete") - +# Try to get from environment first, otherwise fall back to secret store ZENML_API_TOKEN = os.environ.get("ZENML_API_TOKEN") ZENML_STORE_URL = os.environ.get("ZENML_STORE_URL") + +if not ZENML_API_TOKEN or not ZENML_STORE_URL: + # Get ZenML server URL and API token from the secret store + secret = Client().get_secret(SECRET_NAME) + ZENML_API_TOKEN = ZENML_API_TOKEN or secret.secret_values.get( + "zenml_api_token" + ) + ZENML_STORE_URL = ZENML_STORE_URL or secret.secret_values.get( + "zenml_store_url" + ) + SPACE_USERNAME = os.environ.get("ZENML_HF_USERNAME", "zenml") SPACE_NAME = os.environ.get("ZENML_HF_SPACE_NAME", "llm-complete-guide-rag") +SECRET_NAME = os.environ.get("ZENML_PROJECT_SECRET_NAME", "llm-complete") hf_repo_id = f"{SPACE_USERNAME}/{SPACE_NAME}" gcp_reqs = integration_registry.select_integration_requirements("gcp") hf_repo_requirements = f""" -zenml>=0.68.1 +zenml>=0.73.0 ratelimit pgvector psycopg2-binary @@ -38,6 +49,8 @@ datasets torch huggingface-hub +elasticsearch +tenacity {chr(10).join(gcp_reqs)} """ @@ -50,9 +63,7 @@ def predict(message, history): ) -def upload_files_to_repo( - api, repo_id: str, files_mapping: dict, token: str -): +def upload_files_to_repo(api, repo_id: str, files_mapping: dict, token: str): """Upload multiple files to a Hugging Face repository Args: @@ -92,16 +103,28 @@ def gradio_rag_deployment() -> None: exist_ok=True, token=get_hf_token(), ) - api.add_space_secret( - repo_id=hf_repo_id, - key="ZENML_STORE_API_KEY", - value=ZENML_API_TOKEN, - ) - api.add_space_secret( - repo_id=hf_repo_id, - key="ZENML_STORE_URL", - value=ZENML_STORE_URL, - ) + + # Ensure values are strings + if ZENML_API_TOKEN is not None: + api.add_space_secret( + repo_id=hf_repo_id, + key="ZENML_STORE_API_KEY", + value=str(ZENML_API_TOKEN), + ) + + if ZENML_STORE_URL is not None: + api.add_space_secret( + repo_id=hf_repo_id, + key="ZENML_STORE_URL", + value=str(ZENML_STORE_URL), + ) + + if SECRET_NAME is not None: + api.add_space_secret( + repo_id=hf_repo_id, + key="ZENML_PROJECT_SECRET_NAME", + value=str(SECRET_NAME), + ) files_to_upload = { "deployment_hf.py": "app.py", diff --git a/llm-complete-guide/steps/url_scraper.py b/llm-complete-guide/steps/url_scraper.py index 9c54563b..6afec0c7 100644 --- a/llm-complete-guide/steps/url_scraper.py +++ b/llm-complete-guide/steps/url_scraper.py @@ -16,7 +16,7 @@ import json from typing_extensions import Annotated -from zenml import ArtifactConfig, log_artifact_metadata, step +from zenml import ArtifactConfig, log_metadata, step from steps.url_scraping_utils import get_all_pages @@ -26,7 +26,7 @@ def url_scraper( docs_url: str = "https://docs.zenml.io", repo_url: str = "https://github.com/zenml-io/zenml", website_url: str = "https://zenml.io", - use_dev_set: bool = False + use_dev_set: bool = False, ) -> Annotated[str, ArtifactConfig(name="urls")]: """Generates a list of relevant URLs to scrape. @@ -40,9 +40,7 @@ def url_scraper( """ # We comment this out to make this pipeline faster # examples_readme_urls = get_nested_readme_urls(repo_url) - use_dev_set = False if use_dev_set: - docs_urls = [ "https://docs.zenml.io/getting-started/system-architectures", "https://docs.zenml.io/getting-started/core-concepts", @@ -58,10 +56,10 @@ def url_scraper( # website_urls = get_all_pages(website_url) # all_urls = docs_urls + website_urls + examples_readme_urls all_urls = docs_urls - log_artifact_metadata( - artifact_name="urls", + log_metadata( metadata={ "count": len(all_urls), }, + infer_artifact=True, ) return json.dumps(all_urls) diff --git a/llm-complete-guide/steps/url_scraping_utils.py b/llm-complete-guide/steps/url_scraping_utils.py index d6367cbf..ec97ac94 100644 --- a/llm-complete-guide/steps/url_scraping_utils.py +++ b/llm-complete-guide/steps/url_scraping_utils.py @@ -13,14 +13,15 @@ # permissions and limitations under the License. import re -import requests -from bs4 import BeautifulSoup -from typing import List from logging import getLogger +from typing import List +import requests +from bs4 import BeautifulSoup logger = getLogger(__name__) + def get_all_pages(base_url: str = "https://docs.zenml.io") -> List[str]: """ Retrieve all pages from the ZenML documentation sitemap. @@ -32,18 +33,19 @@ def get_all_pages(base_url: str = "https://docs.zenml.io") -> List[str]: List[str]: A list of all documentation page URLs. """ logger.info("Fetching sitemap from docs.zenml.io...") - + # Fetch the sitemap sitemap_url = f"{base_url}/sitemap.xml" response = requests.get(sitemap_url) soup = BeautifulSoup(response.text, "xml") - + # Extract all URLs from the sitemap urls = [loc.text for loc in soup.find_all("loc")] - + logger.info(f"Found {len(urls)} pages in the sitemap.") return urls + def extract_parent_section(url: str) -> str: """ Extracts the parent section from a URL. diff --git a/llm-complete-guide/utils/llm_utils.py b/llm-complete-guide/utils/llm_utils.py index 07516100..7ea75a31 100644 --- a/llm-complete-guide/utils/llm_utils.py +++ b/llm-complete-guide/utils/llm_utils.py @@ -20,6 +20,7 @@ # https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/character.py import logging +import os from elasticsearch import Elasticsearch from zenml.client import Client @@ -48,7 +49,8 @@ OPENAI_MODEL, SECRET_NAME, SECRET_NAME_ELASTICSEARCH, - ZENML_CHATBOT_MODEL, + ZENML_CHATBOT_MODEL_NAME, + ZENML_CHATBOT_MODEL_VERSION, ) from pgvector.psycopg2 import register_vector from psycopg2.extensions import connection @@ -230,8 +232,12 @@ def get_es_client() -> Elasticsearch: Elasticsearch: An Elasticsearch client. """ client = Client() - es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"] - es_api_key = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_api_key"] + es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values[ + "elasticsearch_host" + ] + es_api_key = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values[ + "elasticsearch_api_key" + ] es = Elasticsearch( es_host, @@ -249,28 +255,35 @@ def get_db_conn() -> connection: Returns: connection: A psycopg2 connection object to the PostgreSQL database. """ - client = Client() - CONNECTION_DETAILS = { - "user": client.get_secret(SECRET_NAME).secret_values["supabase_user"], - "password": client.get_secret(SECRET_NAME).secret_values[ - "supabase_password" - ], - "host": client.get_secret(SECRET_NAME).secret_values["supabase_host"], - "port": client.get_secret(SECRET_NAME).secret_values["supabase_port"], - "dbname": "postgres", - } - - return psycopg2.connect(**CONNECTION_DETAILS) + try: + secret = client.get_secret(SECRET_NAME) + logger.debug(f"Secret keys: {list(secret.secret_values.keys())}") + + CONNECTION_DETAILS = { + "user": os.getenv("SUPABASE_USER") + or secret.secret_values["supabase_user"], + "password": os.getenv("SUPABASE_PASSWORD") + or secret.secret_values["supabase_password"], + "host": os.getenv("SUPABASE_HOST") + or secret.secret_values["supabase_host"], + "port": os.getenv("SUPABASE_PORT") + or secret.secret_values["supabase_port"], + "dbname": "postgres", + } + return psycopg2.connect(**CONNECTION_DETAILS) + except KeyError as e: + logger.error(f"Missing key in secret: {e}") + raise def get_topn_similar_docs_pgvector( - query_embedding: List[float], - conn: psycopg2.extensions.connection, - n: int = 5, - include_metadata: bool = False, - only_urls: bool = False - ) -> List[Tuple]: + query_embedding: List[float], + conn: psycopg2.extensions.connection, + n: int = 5, + include_metadata: bool = False, + only_urls: bool = False, +) -> List[Tuple]: """Fetches the top n most similar documents to the given query embedding from the PostgreSQL database. Args: @@ -302,13 +315,14 @@ def get_topn_similar_docs_pgvector( return cur.fetchall() + def get_topn_similar_docs_elasticsearch( - query_embedding: List[float], - es_client: Elasticsearch, - n: int = 5, - include_metadata: bool = False, - only_urls: bool = False - ) -> List[Tuple]: + query_embedding: List[float], + es_client: Elasticsearch, + n: int = 5, + include_metadata: bool = False, + only_urls: bool = False, +) -> List[Tuple]: """Fetches the top n most similar documents to the given query embedding from the Elasticsearch index. Args: @@ -319,7 +333,7 @@ def get_topn_similar_docs_elasticsearch( only_urls (bool, optional): Whether to only return URLs in the results. Defaults to False. """ index_name = "zenml_docs" - + if only_urls: source = ["url"] elif include_metadata: @@ -334,36 +348,42 @@ def get_topn_similar_docs_elasticsearch( "query": {"match_all": {}}, "script": { "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", - "params": {"query_vector": query_embedding} - } + "params": {"query_vector": query_embedding}, + }, } }, - "size": n + "size": n, } # response = es_client.search(index=index_name, body=query) - response = es_client.search(index=index_name, knn={ - "field": "embedding", - "query_vector": query_embedding, - "num_candidates": 50, - "k": n - }) + response = es_client.search( + index=index_name, + knn={ + "field": "embedding", + "query_vector": query_embedding, + "num_candidates": 50, + "k": n, + }, + ) results = [] - for hit in response['hits']['hits']: + for hit in response["hits"]["hits"]: if only_urls: - results.append((hit['_source']['url'],)) + results.append((hit["_source"]["url"],)) elif include_metadata: - results.append(( - hit['_source']['content'], - hit['_source']['url'], - hit['_source']['parent_section'] - )) + results.append( + ( + hit["_source"]["content"], + hit["_source"]["url"], + hit["_source"]["parent_section"], + ) + ) else: - results.append((hit['_source']['content'],)) + results.append((hit["_source"]["content"],)) return results + def get_topn_similar_docs( query_embedding: List[float], conn: psycopg2.extensions.connection = None, @@ -387,15 +407,20 @@ def get_topn_similar_docs( """ if conn is None and es_client is None: raise ValueError("Either conn or es_client must be provided") - + if conn is not None: - return get_topn_similar_docs_pgvector(query_embedding, conn, n, include_metadata, only_urls) - + return get_topn_similar_docs_pgvector( + query_embedding, conn, n, include_metadata, only_urls + ) + if es_client is not None: - return get_topn_similar_docs_elasticsearch(query_embedding, es_client, n, include_metadata, only_urls) + return get_topn_similar_docs_elasticsearch( + query_embedding, es_client, n, include_metadata, only_urls + ) + def get_completion_from_messages( - messages, model=OPENAI_MODEL, temperature=0.4, max_tokens=1000 + messages, model=OPENAI_MODEL, temperature=0, max_tokens=1000 ): """Generates a completion response from the given messages using the specified model. @@ -431,17 +456,21 @@ def get_embeddings(text): model = SentenceTransformer(EMBEDDINGS_MODEL) return model.encode(text) -def find_vectorstore_name() -> str: - """Finds the name of the vector store used for the given embeddings model. - Returns: - str: The name of the vector store. - """ +def find_vectorstore_name() -> str: + """Finds the name of the vector store used for the given embeddings model.""" from zenml.client import Client - client = Client() - model = client.get_model_version(ZENML_CHATBOT_MODEL, model_version_name_or_number_or_id="v0.68.1-dev") - return model.run_metadata["vector_store"].value["name"] + client = Client() + try: + model_version = client.get_model_version( + model_name_or_id=ZENML_CHATBOT_MODEL_NAME, + model_version_name_or_number_or_id=ZENML_CHATBOT_MODEL_VERSION, + ) + return model_version.run_metadata["vector_store"]["name"] + except KeyError: + logger.error("Vector store metadata not found in model version") + return "pgvector" # Fallback to default def rerank_documents(