@@ -285,9 +285,7 @@ def get_db_conn() -> connection:
285285 raise
286286
287287
288- def get_pinecone_client (
289- model_version_stage : str = "staging" ,
290- ) -> pinecone .Index :
288+ def get_pinecone_client (model_version_name_or_id : str = "dev" ) -> pinecone .Index :
291289 """Get a Pinecone index client.
292290
293291 Returns:
@@ -305,24 +303,35 @@ def get_pinecone_client(
305303 # raise error if there is no index name attached to the metadata
306304 model_version = client .get_model_version (
307305 model_name_or_id = ZENML_CHATBOT_MODEL_NAME ,
308- model_version_name_or_number_or_id = model_version_stage ,
306+ model_version_name_or_number_or_id = model_version_name_or_id ,
309307 )
310308
311- if model_version_stage == "staging" :
309+ index_name_from_secret = client .get_secret (SECRET_NAME_PINECONE ).secret_values .get ("pinecone_index" , "zenml-docs" )
310+
311+ if model_version_name_or_id == "production" :
312+ index_name = f"{ index_name_from_secret } -prod"
313+
314+ model_version .run_metadata ["vector_store" ]["index_name" ] = index_name
315+
316+ # delete index if it exists
317+ if index_name in pc .list_indexes ().names ():
318+ pc .delete_index (index_name )
319+
320+ # create index
321+ pc .create_index (
322+ name = index_name ,
323+ dimension = EMBEDDING_DIMENSIONALITY ,
324+ metric = "cosine" ,
325+ spec = ServerlessSpec (cloud = "aws" , region = "us-east-1" )
326+ )
327+ else :
312328 try :
313329 index_name = model_version .run_metadata ["vector_store" ][
314330 "index_name"
315331 ]
316332 except KeyError :
317- index_name = client .get_secret (
318- SECRET_NAME_PINECONE
319- ).secret_values .get ("pinecone_index" , "zenml-docs-dev" )
320- # if index by that name exists already, create a new one with a random suffix
321- if index_name in pc .list_indexes ().names ():
322- index_name = f"{ index_name } -{ uuid .uuid4 ()} "
323- model_version .run_metadata ["vector_store" ]["index_name" ] = (
324- index_name
325- )
333+ index_name = index_name_from_secret
334+ model_version .run_metadata ["vector_store" ]["index_name" ] = index_name
326335
327336 # Create index if it doesn't exist
328337 if index_name not in pc .list_indexes ().names ():
@@ -332,23 +341,7 @@ def get_pinecone_client(
332341 metric = "cosine" ,
333342 spec = ServerlessSpec (cloud = "aws" , region = "us-east-1" ),
334343 )
335-
336- if model_version_stage == "production" :
337- try :
338- index_name = model_version .run_metadata ["vector_store" ][
339- "index_name"
340- ]
341- except KeyError :
342- raise ValueError (
343- "The production model version should have an index name attached to it. None found."
344- )
345-
346- # if index doesn't exist, raise error
347- if index_name not in pc .list_indexes ().names ():
348- raise ValueError (
349- f"The index { index_name } attached to the production model version does not exist. Please create it first."
350- )
351-
344+
352345 return pc .Index (index_name )
353346
354347
@@ -679,9 +672,7 @@ def process_input_with_retrieval(
679672 include_metadata = True ,
680673 )
681674 elif vector_store == "pinecone" :
682- pinecone_index = get_pinecone_client (
683- model_version_stage = model_version_stage
684- )
675+ pinecone_index = get_pinecone_client (model_version_name_or_id = model_version_stage )
685676 similar_docs = get_topn_similar_docs (
686677 query_embedding = query_embedding ,
687678 pinecone_index = pinecone_index ,
0 commit comments