Skip to content

Commit c648381

Browse files
committed
use the right index for the model stage
1 parent b3f062e commit c648381

File tree

6 files changed

+63
-34
lines changed

6 files changed

+63
-34
lines changed

llm-complete-guide/deployment_hf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323

2424
def predict(message, history):
2525
try:
26+
# add the prod flag here
2627
return process_input_with_retrieval(
2728
input=message,
2829
n_items_retrieved=20,
2930
use_reranking=True,
31+
model_version_stage="production",
3032
)
3133
except Exception as e:
3234
logger.error(f"Error processing message: {e}")

llm-complete-guide/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,9 @@ def main(
232232
raise click.UsageError(
233233
"--query-text is required when using 'query' command"
234234
)
235+
# add the prod flag here
235236
response = process_input_with_retrieval(
236-
query_text, model=model, use_reranking=use_reranker
237+
query_text, model=model, use_reranking=use_reranker, model_version_stage="production"
237238
)
238239
console = Console()
239240
md = Markdown(response)

llm-complete-guide/steps/eval_retrieval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def query_similar_docs(
9696
if vector_store_name == "pgvector":
9797
conn = get_db_conn()
9898
elif vector_store_name == "pinecone":
99-
pinecone_index = get_pinecone_client()
99+
# in pipeline runs, always use staging index
100+
pinecone_index = get_pinecone_client(model_version_stage="staging")
100101
else:
101102
es_client = get_es_client()
102103

llm-complete-guide/steps/populate_index.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
from PIL import Image, ImageDraw, ImageFont
4949
from sentence_transformers import SentenceTransformer
5050
from structures import Document
51-
from utils.llm_utils import get_db_conn, get_es_client, split_documents
52-
from zenml import ArtifactConfig, log_metadata, step
51+
from utils.llm_utils import get_db_conn, get_es_client, get_pinecone_client, split_documents
52+
from zenml import ArtifactConfig, get_step_context, log_metadata, step
5353
from zenml.client import Client
5454
from zenml.metadata.metadata_types import Uri
5555
import pinecone
@@ -642,12 +642,15 @@ def index_generator(
642642
documents (str): JSON string containing the documents to index.
643643
index_type (IndexType, optional): Type of index to generate. Defaults to IndexType.POSTGRES.
644644
"""
645+
# get model version
646+
context = get_step_context()
647+
model_version_stage = context.model_version.stage
645648
if index_type == IndexType.ELASTICSEARCH:
646649
_index_generator_elastic(documents)
647650
elif index_type == IndexType.POSTGRES:
648651
_index_generator_postgres(documents)
649652
elif index_type == IndexType.PINECONE:
650-
_index_generator_pinecone(documents)
653+
_index_generator_pinecone(documents, model_version_stage)
651654
else:
652655
raise ValueError(f"Unknown index type: {index_type}")
653656

@@ -822,33 +825,14 @@ def _index_generator_postgres(documents: str) -> None:
822825
conn.close()
823826

824827

825-
def _index_generator_pinecone(documents: str) -> None:
828+
def _index_generator_pinecone(documents: str, model_version_stage: str) -> None:
826829
"""Generates a Pinecone index for the given documents.
827830
828831
Args:
829832
documents (str): JSON string containing the documents to index.
833+
model_version_stage (str): The stage of the model version (e.g. "staging" or "production").
830834
"""
831-
client = Client()
832-
pinecone_api_key = client.get_secret(SECRET_NAME_PINECONE).secret_values["pinecone_api_key"]
833-
index_name = client.get_secret(SECRET_NAME_PINECONE).secret_values.get("pinecone_index", "zenml-docs")
834-
835-
# Initialize Pinecone
836-
pc = Pinecone(api_key=pinecone_api_key)
837-
838-
# Create index if it doesn't exist
839-
if index_name not in pc.list_indexes().names():
840-
pc.create_index(
841-
name=index_name,
842-
dimension=EMBEDDING_DIMENSIONALITY,
843-
metric="cosine",
844-
spec=ServerlessSpec(
845-
cloud="aws",
846-
region="us-east-1"
847-
)
848-
)
849-
850-
# Get the index
851-
index = pc.Index(index_name)
835+
index = get_pinecone_client(model_version_stage=model_version_stage)
852836

853837
# Load documents
854838
docs = json.loads(documents)
@@ -886,7 +870,7 @@ def _index_generator_pinecone(documents: str) -> None:
886870
if batch:
887871
index.upsert(vectors=batch)
888872

889-
logger.info(f"Successfully indexed {len(docs)} documents to Pinecone index '{index_name}'")
873+
logger.info(f"Successfully indexed {len(docs)} documents to Pinecone index")
890874

891875

892876
def _log_metadata(index_type: IndexType) -> None:

llm-complete-guide/steps/rag_deployment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@
5656

5757

5858
def predict(message, history):
59+
# add the prod flag here
5960
return process_input_with_retrieval(
6061
input=message,
6162
n_items_retrieved=20,
6263
use_reranking=True,
64+
model_version_stage="production",
6365
)
6466

6567

llm-complete-guide/utils/llm_utils.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from utils.openai_utils import get_openai_api_key
2929
import pinecone
30-
from pinecone import Pinecone
30+
from pinecone import Pinecone, ServerlessSpec
3131
# Configure logging levels for specific modules
3232
logging.getLogger("pytorch").setLevel(logging.CRITICAL)
3333
logging.getLogger("sentence-transformers").setLevel(logging.CRITICAL)
@@ -45,6 +45,7 @@
4545
import psycopg2
4646
import tiktoken
4747
from constants import (
48+
EMBEDDING_DIMENSIONALITY,
4849
EMBEDDINGS_MODEL,
4950
MODEL_NAME_MAP,
5051
OPENAI_MODEL,
@@ -279,17 +280,54 @@ def get_db_conn() -> connection:
279280
raise
280281

281282

282-
def get_pinecone_client() -> pinecone.Index:
283+
def get_pinecone_client(model_version_stage: str = "staging") -> pinecone.Index:
283284
"""Get a Pinecone index client.
284285
285286
Returns:
286287
pinecone.Index: A Pinecone index client.
287288
"""
288289
client = Client()
289290
pinecone_api_key = client.get_secret(SECRET_NAME_PINECONE).secret_values["pinecone_api_key"]
290-
index_name = client.get_secret(SECRET_NAME_PINECONE).secret_values.get("pinecone_index", "zenml-docs")
291-
292291
pc = Pinecone(api_key=pinecone_api_key)
292+
293+
# if the model version is staging, we check if any index name is associated as metadata
294+
# if not, create a new one with the name from the secret and attach it to the metadata
295+
# if the model version is production, we just use the index name from the metadata attached to it
296+
# raise error if there is no index name attached to the metadata
297+
model_version = client.get_model_version(
298+
model_name_or_id=ZENML_CHATBOT_MODEL_NAME,
299+
model_version_name_or_number_or_id=model_version_stage,
300+
)
301+
302+
if model_version_stage == "staging":
303+
try:
304+
index_name = model_version.run_metadata["vector_store"]["index_name"]
305+
except KeyError:
306+
index_name = client.get_secret(SECRET_NAME_PINECONE).secret_values.get("pinecone_index", "zenml-docs-dev")
307+
model_version.run_metadata["vector_store"]["index_name"] = index_name
308+
309+
# Create index if it doesn't exist
310+
if index_name not in pc.list_indexes().names():
311+
pc.create_index(
312+
name=index_name,
313+
dimension=EMBEDDING_DIMENSIONALITY,
314+
metric="cosine",
315+
spec=ServerlessSpec(
316+
cloud="aws",
317+
region="us-east-1"
318+
)
319+
)
320+
321+
if model_version_stage == "production":
322+
try:
323+
index_name = model_version.run_metadata["vector_store"]["index_name"]
324+
except KeyError:
325+
raise ValueError("The production model version should have an index name attached to it. None found.")
326+
327+
# if index doesn't exist, raise error
328+
if index_name not in pc.list_indexes().names():
329+
raise ValueError(f"The index {index_name} attached to the production model version does not exist. Please create it first.")
330+
293331
return pc.Index(index_name)
294332

295333

@@ -579,6 +617,7 @@ def process_input_with_retrieval(
579617
model: str = OPENAI_MODEL,
580618
n_items_retrieved: int = 20,
581619
use_reranking: bool = False,
620+
model_version_stage: str = "staging",
582621
) -> str:
583622
"""Process the input with retrieval.
584623
@@ -590,7 +629,7 @@ def process_input_with_retrieval(
590629
the database. Defaults to 5.
591630
use_reranking (bool, optional): Whether to use reranking. Defaults to
592631
False.
593-
632+
model_version_stage (str, optional): The stage of the model version. Defaults to "staging".
594633
Returns:
595634
str: The processed output.
596635
"""
@@ -609,7 +648,7 @@ def process_input_with_retrieval(
609648
include_metadata=True,
610649
)
611650
elif vector_store == "pinecone":
612-
pinecone_index = get_pinecone_client()
651+
pinecone_index = get_pinecone_client(model_version_stage=model_version_stage)
613652
similar_docs = get_topn_similar_docs(
614653
query_embedding=query_embedding,
615654
pinecone_index=pinecone_index,

0 commit comments

Comments
 (0)