2727
2828from utils .openai_utils import get_openai_api_key
2929import pinecone
30- from pinecone import Pinecone
30+ from pinecone import Pinecone , ServerlessSpec
3131# Configure logging levels for specific modules
3232logging .getLogger ("pytorch" ).setLevel (logging .CRITICAL )
3333logging .getLogger ("sentence-transformers" ).setLevel (logging .CRITICAL )
4545import psycopg2
4646import tiktoken
4747from constants import (
48+ EMBEDDING_DIMENSIONALITY ,
4849 EMBEDDINGS_MODEL ,
4950 MODEL_NAME_MAP ,
5051 OPENAI_MODEL ,
@@ -279,17 +280,54 @@ def get_db_conn() -> connection:
279280 raise
280281
281282
def get_pinecone_client(model_version_stage: str = "staging") -> pinecone.Index:
    """Get a Pinecone index client for the given model version stage.

    The index name is looked up in the run metadata of the ZenML model
    version registered under ``ZENML_CHATBOT_MODEL_NAME`` for the requested
    stage:

    * ``"staging"``: use the index name from the metadata if present;
      otherwise fall back to the ``pinecone_index`` value of the Pinecone
      secret (default ``"zenml-docs-dev"``). The index is created if it
      does not exist yet.
    * ``"production"``: the index name must be present in the metadata and
      the index must already exist.

    Args:
        model_version_stage: Stage of the model version to resolve the index
            for. Must be ``"staging"`` or ``"production"``. Defaults to
            ``"staging"``.

    Returns:
        pinecone.Index: A Pinecone index client.

    Raises:
        ValueError: If ``model_version_stage`` is not a supported stage, if
            the production model version has no index name attached, or if
            the production index does not exist.
    """
    # Fail fast: previously an unknown stage fell through both branches and
    # crashed with an UnboundLocalError on ``index_name``.
    if model_version_stage not in ("staging", "production"):
        raise ValueError(
            f"Unsupported model version stage '{model_version_stage}'. "
            "Expected 'staging' or 'production'."
        )

    client = Client()
    # Fetch the secret once; it holds the API key and, optionally, the
    # fallback index name.
    pinecone_secrets = client.get_secret(SECRET_NAME_PINECONE).secret_values
    pc = Pinecone(api_key=pinecone_secrets["pinecone_api_key"])

    model_version = client.get_model_version(
        model_name_or_id=ZENML_CHATBOT_MODEL_NAME,
        model_version_name_or_number_or_id=model_version_stage,
    )

    if model_version_stage == "staging":
        try:
            index_name = model_version.run_metadata["vector_store"]["index_name"]
        except KeyError:
            index_name = pinecone_secrets.get("pinecone_index", "zenml-docs-dev")
            # ``setdefault`` avoids a second KeyError inside this handler
            # when the "vector_store" entry itself is missing.
            # NOTE(review): this only updates the in-memory response object;
            # presumably it is not persisted to the ZenML server -- confirm
            # and log the metadata explicitly if persistence is intended.
            model_version.run_metadata.setdefault("vector_store", {})[
                "index_name"
            ] = index_name

        # Create the index if it doesn't exist.
        if index_name not in pc.list_indexes().names():
            pc.create_index(
                name=index_name,
                dimension=EMBEDDING_DIMENSIONALITY,
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1"),
            )
    else:  # production
        try:
            index_name = model_version.run_metadata["vector_store"]["index_name"]
        except KeyError as err:
            raise ValueError(
                "The production model version should have an index name attached to it. None found."
            ) from err

        # Production must never create indexes implicitly.
        if index_name not in pc.list_indexes().names():
            raise ValueError(
                f"The index {index_name} attached to the production model version does not exist. Please create it first."
            )

    return pc.Index(index_name)
294332
295333
@@ -579,6 +617,7 @@ def process_input_with_retrieval(
579617 model : str = OPENAI_MODEL ,
580618 n_items_retrieved : int = 20 ,
581619 use_reranking : bool = False ,
620+ model_version_stage : str = "staging" ,
582621) -> str :
583622 """Process the input with retrieval.
584623
@@ -590,7 +629,7 @@ def process_input_with_retrieval(
590629 the database. Defaults to 5.
591630 use_reranking (bool, optional): Whether to use reranking. Defaults to
592631 False.
593-
632+ model_version_stage (str, optional): The stage of the model version. Defaults to "staging".
594633 Returns:
595634 str: The processed output.
596635 """
@@ -609,7 +648,7 @@ def process_input_with_retrieval(
609648 include_metadata = True ,
610649 )
611650 elif vector_store == "pinecone" :
612- pinecone_index = get_pinecone_client ()
651+ pinecone_index = get_pinecone_client (model_version_stage = model_version_stage )
613652 similar_docs = get_topn_similar_docs (
614653 query_embedding = query_embedding ,
615654 pinecone_index = pinecone_index ,
0 commit comments