Skip to content

Commit 83158e8

Browse files
committed
check model metadata to figure out what vectorstore to use
1 parent b49d425 commit 83158e8

File tree

2 files changed

+115
-17
lines changed

2 files changed

+115
-17
lines changed

llm-complete-guide/steps/eval_retrieval.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from datasets import load_dataset
2121
from utils.llm_utils import (
22+
find_vectorstore_name,
2223
get_db_conn,
2324
get_embeddings,
2425
get_es_client,
@@ -77,11 +78,23 @@ def query_similar_docs(
7778
Tuple containing the question, URL ending, and retrieved URLs.
7879
"""
7980
embedded_question = get_embeddings(question)
80-
es_client = get_es_client()
81+
conn = None
82+
es_client = None
83+
84+
vector_store_name = find_vectorstore_name()
85+
if vector_store_name == "pgvector":
86+
conn = get_db_conn()
87+
else:
88+
es_client = get_es_client()
89+
8190
num_docs = 20 if use_reranking else returned_sample_size
8291
# get (content, url) tuples for the top n similar documents
8392
top_similar_docs = get_topn_similar_docs(
84-
embedded_question, es_client, n=num_docs, include_metadata=True
93+
embedded_question,
94+
conn=conn,
95+
es_client=es_client,
96+
n=num_docs,
97+
include_metadata=True
8598
)
8699

87100
if use_reranking:

llm-complete-guide/utils/llm_utils.py

Lines changed: 100 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
OPENAI_MODEL,
4949
SECRET_NAME,
5050
SECRET_NAME_ELASTICSEARCH,
51+
ZENML_CHATBOT_MODEL,
5152
)
5253
from pgvector.psycopg2 import register_vector
5354
from psycopg2.extensions import connection
@@ -263,25 +264,59 @@ def get_db_conn() -> connection:
263264
return psycopg2.connect(**CONNECTION_DETAILS)
264265

265266

266-
def get_topn_similar_docs(
267-
query_embedding: List[float],
268-
es_client: Elasticsearch,
269-
n: int = 5,
270-
include_metadata: bool = False,
271-
only_urls: bool = False,
272-
) -> List[Tuple]:
273-
"""Fetches the top n most similar documents to the given query embedding from the database.
267+
def get_topn_similar_docs_pgvector(
    query_embedding: List[float],
    conn: psycopg2.extensions.connection,
    n: int = 5,
    include_metadata: bool = False,
    only_urls: bool = False,
) -> List[Tuple]:
    """Fetches the top n most similar documents to the given query embedding from the PostgreSQL database.

    Rows of the ``embeddings`` table are ranked by pgvector's cosine-distance
    operator (``<=>``) against ``query_embedding``.

    Args:
        query_embedding (list): The query embedding to compare against.
        conn (psycopg2.extensions.connection): The database connection object.
        n (int, optional): The number of similar documents to fetch. Defaults to 5.
        include_metadata (bool, optional): Whether to include metadata
            (url, parent_section) in the results. Defaults to False. Takes
            precedence over only_urls.
        only_urls (bool, optional): Whether to only return URLs in the results.
            Defaults to False.

    Returns:
        List[Tuple]: Row tuples ordered from most to least similar.
    """
    embedding_array = np.array(query_embedding)
    # Register the pgvector adapter so the numpy array can be bound as a
    # query parameter.
    register_vector(conn)

    # Only the selected columns differ between the three modes; the ordering
    # and limit logic is shared. The column list comes from a fixed whitelist,
    # never from user input.
    if include_metadata:
        columns = "content, url, parent_section"
    elif only_urls:
        columns = "url"
    else:
        columns = "content"

    # Context-managed cursor so it is always closed, and LIMIT passed as a
    # bound parameter rather than interpolated into the SQL string.
    with conn.cursor() as cur:
        cur.execute(
            f"SELECT {columns} FROM embeddings ORDER BY embedding <=> %s LIMIT %s",
            (embedding_array, n),
        )
        return cur.fetchall()
304+
305+
def get_topn_similar_docs_elasticsearch(
306+
query_embedding: List[float],
307+
es_client: Elasticsearch,
308+
n: int = 5,
309+
include_metadata: bool = False,
310+
only_urls: bool = False
311+
) -> List[Tuple]:
312+
"""Fetches the top n most similar documents to the given query embedding from the Elasticsearch index.
313+
314+
Args:
315+
query_embedding (list): The query embedding to compare against.
316+
es_client (Elasticsearch): The Elasticsearch client.
317+
n (int, optional): The number of similar documents to fetch. Defaults to 5.
318+
include_metadata (bool, optional): Whether to include metadata in the results. Defaults to False.
319+
only_urls (bool, optional): Whether to only return URLs in the results. Defaults to False.
285320
"""
286321
index_name = "zenml_docs"
287322

@@ -329,6 +364,35 @@ def get_topn_similar_docs(
329364

330365
return results
331366

367+
def get_topn_similar_docs(
    query_embedding: List[float],
    conn: psycopg2.extensions.connection = None,
    es_client: Elasticsearch = None,
    n: int = 5,
    include_metadata: bool = False,
    only_urls: bool = False,
) -> List[Tuple]:
    """Fetches the top n most similar documents to the given query embedding.

    Dispatches to the pgvector or Elasticsearch backend depending on which
    client is provided. If both are given, the pgvector connection wins.

    Args:
        query_embedding (list): The query embedding to compare against.
        conn (psycopg2.extensions.connection, optional): PostgreSQL connection;
            when provided, the pgvector backend is used.
        es_client (Elasticsearch, optional): Elasticsearch client; used when
            conn is not provided.
        n (int, optional): The number of similar documents to fetch. Defaults to
            5.
        include_metadata (bool, optional): Whether to include metadata in the
            results. Defaults to False.
        only_urls (bool, optional): Whether to only return URLs in the results.
            Defaults to False.

    Returns:
        list: A list of tuples containing the content and metadata (if include_metadata is True) of the top n most similar documents.

    Raises:
        ValueError: If neither conn nor es_client is provided.
    """
    if conn is None and es_client is None:
        raise ValueError("Either conn or es_client must be provided")

    if conn is not None:
        return get_topn_similar_docs_pgvector(
            query_embedding, conn, n, include_metadata, only_urls
        )

    # es_client is guaranteed non-None here, so there is no implicit
    # fall-through returning None.
    return get_topn_similar_docs_elasticsearch(
        query_embedding, es_client, n, include_metadata, only_urls
    )
332396

333397
def get_completion_from_messages(
334398
messages, model=OPENAI_MODEL, temperature=0.4, max_tokens=1000
@@ -367,6 +431,18 @@ def get_embeddings(text):
367431
model = SentenceTransformer(EMBEDDINGS_MODEL)
368432
return model.encode(text)
369433

434+
def find_vectorstore_name(model_version: str = "v0.68.1-dev") -> str:
    """Finds the name of the vector store recorded in the model's run metadata.

    Looks up the ZENML_CHATBOT_MODEL model version and reads the
    ``vector_store`` entry from its run metadata (e.g. "pgvector" or
    "elasticsearch").

    Args:
        model_version (str, optional): The model version name/number/id to
            inspect. Defaults to the previously hard-coded "v0.68.1-dev".

    Returns:
        str: The name of the vector store.
    """
    # Imported locally so merely importing this module does not require an
    # initialized ZenML client.
    from zenml.client import Client

    client = Client()
    model = client.get_model_version(
        ZENML_CHATBOT_MODEL,
        model_version_name_or_number_or_id=model_version,
    )
    return model.run_metadata["vector_store"].value["name"]
445+
370446

371447
def rerank_documents(
372448
query: str, documents: List[Tuple], reranker_model: str = "flashrank"
@@ -420,11 +496,20 @@ def process_input_with_retrieval(
420496
str: The processed output.
421497
"""
422498
delimiter = "```"
499+
es_client = None
500+
conn = None
501+
502+
vector_store_name = find_vectorstore_name()
503+
if vector_store_name == "pgvector":
504+
conn = get_db_conn()
505+
else:
506+
es_client = get_es_client()
423507

424508
# Step 1: Get documents related to the user input from database
425509
related_docs = get_topn_similar_docs(
426510
get_embeddings(input),
427-
get_es_client(),
511+
conn=conn,
512+
es_client=es_client,
428513
n=n_items_retrieved,
429514
include_metadata=use_reranking,
430515
)

0 commit comments

Comments (0)