
Commit 02c158a

add ability to switch bw elastic and pgvector
1 parent 83158e8 commit 02c158a

2 files changed: +155 -44 lines changed

llm-complete-guide/constants.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@
     384 # Update this to match the dimensionality of the new model
 )
 
+# ZenML constants
+ZENML_CHATBOT_MODEL = "zenml-docs-qa-chatbot"
+
 # Scraping constants
 RATE_LIMIT = 5 # Maximum number of requests per second
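The new constant is only defined here; nothing in constants.py consumes it yet. As a hedged sketch of the likely intent (not part of this commit), a pipeline could pass it to ZenML's Model so that metadata logged by the indexing step attaches to one consistently named model entry. The pipeline name llm_basic_rag and the description string below are assumptions for illustration.

# Hypothetical usage sketch -- not part of this commit.
from zenml import Model, pipeline

from constants import ZENML_CHATBOT_MODEL


@pipeline(
    model=Model(
        name=ZENML_CHATBOT_MODEL,  # one canonical name for the docs-QA chatbot
        description="RAG chatbot over the ZenML documentation",  # assumed description
    )
)
def llm_basic_rag() -> None:  # pipeline name assumed for illustration
    # Steps that scrape, chunk, embed, and index the docs would be wired up here.
    ...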

llm-complete-guide/steps/populate_index.py

Lines changed: 152 additions & 44 deletions
@@ -19,17 +19,20 @@
 # https://www.timescale.com/blog/postgresql-as-a-vector-database-create-store-and-query-openai-embeddings-with-pgvector/
 # for providing the base implementation for this indexing functionality
 
+import hashlib
 import json
 import logging
 import math
 from typing import Annotated, Any, Dict, List, Tuple
+from enum import Enum
 
 from constants import (
     CHUNK_OVERLAP,
     CHUNK_SIZE,
     EMBEDDING_DIMENSIONALITY,
     EMBEDDINGS_MODEL,
     SECRET_NAME_ELASTICSEARCH,
+    ZENML_CHATBOT_MODEL,
 )
 from pgvector.psycopg2 import register_vector
 from PIL import Image, ImageDraw, ImageFont
@@ -593,9 +596,14 @@ def generate_embeddings(
         raise
 
 
+class IndexType(Enum):
+    ELASTICSEARCH = "elasticsearch"
+    POSTGRES = "postgres"
+
 @step(enable_cache=False)
 def index_generator(
     documents: str,
+    index_type: IndexType = IndexType.ELASTICSEARCH,
 ) -> None:
     """Generates an index for the given documents.
@@ -606,14 +614,23 @@ def index_generator(
 
     Args:
         documents (str): A JSON string containing the Document objects with generated embeddings.
+        index_type (IndexType): The type of index to use. Defaults to Elasticsearch.
 
     Raises:
         Exception: If an error occurs during the index generation.
     """
-    from elasticsearch import Elasticsearch
-    from elasticsearch.helpers import bulk
-    import hashlib
-
+    try:
+        if index_type == IndexType.ELASTICSEARCH:
+            _index_generator_elastic(documents)
+        else:
+            _index_generator_postgres(documents)
+
+    except Exception as e:
+        logger.error(f"Error in index_generator: {e}")
+        raise
+
+def _index_generator_elastic(documents: str) -> None:
+    """Generates an Elasticsearch index for the given documents."""
     try:
         es = get_es_client()
         index_name = "zenml_docs"
@@ -643,16 +660,13 @@ def index_generator
 
         # Parse the JSON string into a list of Document objects
         document_list = [Document(**doc) for doc in json.loads(documents)]
-
-        # Prepare bulk operations
         operations = []
+
        for doc in document_list:
-            # Create a unique identifier based on content and metadata
             content_hash = hashlib.md5(
                 f"{doc.page_content}{doc.filename}{doc.parent_section}{doc.url}".encode()
             ).hexdigest()
 
-            # Check if document exists
             exists_query = {
                 "query": {
                     "term": {
@@ -694,45 +708,139 @@ def index_generator
         else:
             logger.info("No new documents to index")
 
-        # Log the model metadata
-        prompt = """
-        You are a friendly chatbot. \
-        You can answer questions about ZenML, its features and its use cases. \
-        You respond in a concise, technically credible tone. \
-        You ONLY use the context from the ZenML documentation to provide relevant
-        answers. \
-        You do not make up answers or provide opinions that you don't have
-        information to support. \
-        If you are unsure or don't know, just say so. \
-        """
-
-        client = Client()
+        _log_metadata(index_type=IndexType.ELASTICSEARCH)
+
+    except Exception as e:
+        logger.error(f"Error in Elasticsearch indexing: {e}")
+        raise
+
+def _index_generator_postgres(documents: str) -> None:
+    """Generates a PostgreSQL index for the given documents."""
+    try:
+        conn = get_db_conn()
+
+        with conn.cursor() as cur:
+            # Install pgvector if not already installed
+            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            conn.commit()
+
+            # Create the embeddings table if it doesn't exist
+            table_create_command = f"""
+            CREATE TABLE IF NOT EXISTS embeddings (
+                id SERIAL PRIMARY KEY,
+                content TEXT,
+                token_count INTEGER,
+                embedding VECTOR({EMBEDDING_DIMENSIONALITY}),
+                filename TEXT,
+                parent_section TEXT,
+                url TEXT
+            );
+            """
+            cur.execute(table_create_command)
+            conn.commit()
+
+            register_vector(conn)
+
+            # Parse the JSON string into a list of Document objects
+            document_list = [Document(**doc) for doc in json.loads(documents)]
+
+            # Insert data only if it doesn't already exist
+            for doc in document_list:
+                content = doc.page_content
+                token_count = doc.token_count
+                embedding = doc.embedding
+                filename = doc.filename
+                parent_section = doc.parent_section
+                url = doc.url
+
+                cur.execute(
+                    "SELECT COUNT(*) FROM embeddings WHERE content = %s",
+                    (content,),
+                )
+                count = cur.fetchone()[0]
+                if count == 0:
+                    cur.execute(
+                        "INSERT INTO embeddings (content, token_count, embedding, filename, parent_section, url) VALUES (%s, %s, %s, %s, %s, %s)",
+                        (
+                            content,
+                            token_count,
+                            embedding,
+                            filename,
+                            parent_section,
+                            url,
+                        ),
+                    )
+            conn.commit()
+
+
+            cur.execute("SELECT COUNT(*) as cnt FROM embeddings;")
+            num_records = cur.fetchone()[0]
+            logger.info(f"Number of vector records in table: {num_records}")
+
+            # calculate the index parameters according to best practices
+            num_lists = max(num_records / 1000, 10)
+            if num_records > 1000000:
+                num_lists = math.sqrt(num_records)
+
+            # use the cosine distance measure, which is what we'll later use for querying
+            cur.execute(
+                f"CREATE INDEX IF NOT EXISTS embeddings_idx ON embeddings USING ivfflat (embedding vector_cosine_ops) WITH (lists = {num_lists});"
+            )
+            conn.commit()
+
+        _log_metadata(index_type=IndexType.POSTGRES)
+
+    except Exception as e:
+        logger.error(f"Error in PostgreSQL indexing: {e}")
+        raise
+    finally:
+        if conn:
+            conn.close()
+
+def _log_metadata(index_type: IndexType) -> None:
+    """Log metadata about the indexing process."""
+    prompt = """
+    You are a friendly chatbot. \
+    You can answer questions about ZenML, its features and its use cases. \
+    You respond in a concise, technically credible tone. \
+    You ONLY use the context from the ZenML documentation to provide relevant answers. \
+    You do not make up answers or provide opinions that you don't have information to support. \
+    If you are unsure or don't know, just say so. \
+    """
+
+    client = Client()
+
+    if index_type == IndexType.ELASTICSEARCH:
         es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"]
-        CONNECTION_DETAILS = {
+        connection_details = {
             "host": es_host,
             "api_key": "*********",
         }
+        store_name = "elasticsearch"
+    else:
+        store_name = "pgvector"
+
+        connection_details = {
+            "user": client.get_secret(SECRET_NAME).secret_values["supabase_user"],
+            "password": "**********",
+            "host": client.get_secret(SECRET_NAME).secret_values["supabase_host"],
+            "port": client.get_secret(SECRET_NAME).secret_values["supabase_port"],
+            "dbname": "postgres",
+        }
 
-        log_model_metadata(
-            metadata={
-                "embeddings": {
-                    "model": EMBEDDINGS_MODEL,
-                    "dimensionality": EMBEDDING_DIMENSIONALITY,
-                    "model_url": Uri(
-                        f"https://huggingface.co/{EMBEDDINGS_MODEL}"
-                    ),
-                },
-                "prompt": {
-                    "content": prompt,
-                },
-                "vector_store": {
-                    "name": "elasticsearch",
-                    "connection_details": CONNECTION_DETAILS,
-                    "index_name": index_name
-                },
+    log_model_metadata(
+        metadata={
+            "embeddings": {
+                "model": EMBEDDINGS_MODEL,
+                "dimensionality": EMBEDDING_DIMENSIONALITY,
+                "model_url": Uri(f"https://huggingface.co/{EMBEDDINGS_MODEL}"),
             },
-        )
-
-    except Exception as e:
-        logger.error(f"Error in index_generator: {e}")
-        raise
+            "prompt": {
+                "content": prompt,
+            },
+            "vector_store": {
+                "name": store_name,
+                "connection_details": connection_details,
+            },
+        },
+    )
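Net effect: index_generator now dispatches on the new index_type argument, Elasticsearch stays the default, and both backends share _log_metadata. A minimal wiring sketch of how a pipeline might opt into the pgvector/PostgreSQL path (hypothetical, not part of this commit; the pipeline and the stub load_embedded_docs step are illustrations only):

# Hypothetical wiring sketch -- not part of this commit.
import json

from zenml import pipeline, step

from steps.populate_index import IndexType, index_generator


@step
def load_embedded_docs() -> str:
    # Stand-in for the real upstream step that returns the JSON string of
    # documents with embeddings (assumed name).
    return json.dumps([])


@pipeline
def populate_index_pipeline(use_pgvector: bool = False) -> None:
    documents = load_embedded_docs()
    index_generator(
        documents=documents,
        # Flip the backend without touching the step code.
        index_type=IndexType.POSTGRES if use_pgvector else IndexType.ELASTICSEARCH,
    )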
