2525
2626import pinecone
2727from elasticsearch import Elasticsearch
28- from pinecone import Pinecone
28+ from pinecone import Pinecone , ServerlessSpec
2929from zenml .client import Client
3030
3131from utils .openai_utils import get_openai_api_key
32- import pinecone
33- from pinecone import Pinecone , ServerlessSpec
32+
3433# Configure logging levels for specific modules
3534logging .getLogger ("pytorch" ).setLevel (logging .CRITICAL )
3635logging .getLogger ("sentence-transformers" ).setLevel (logging .CRITICAL )
@@ -286,14 +285,18 @@ def get_db_conn() -> connection:
286285 raise
287286
288287
289- def get_pinecone_client (model_version_stage : str = "staging" ) -> pinecone .Index :
288+ def get_pinecone_client (
289+ model_version_stage : str = "staging" ,
290+ ) -> pinecone .Index :
290291 """Get a Pinecone index client.
291292
292293 Returns:
293294 pinecone.Index: A Pinecone index client.
294295 """
295296 client = Client ()
296- pinecone_api_key = client .get_secret (SECRET_NAME_PINECONE ).secret_values ["pinecone_api_key" ]
297+ pinecone_api_key = client .get_secret (SECRET_NAME_PINECONE ).secret_values [
298+ "pinecone_api_key"
299+ ]
297300 pc = Pinecone (api_key = pinecone_api_key )
298301
299302 # if the model versio is staging, we check if any index name is associated as metadata
@@ -307,35 +310,44 @@ def get_pinecone_client(model_version_stage: str = "staging") -> pinecone.Index:
307310
308311 if model_version_stage == "staging" :
309312 try :
310- index_name = model_version .run_metadata ["vector_store" ]["index_name" ]
313+ index_name = model_version .run_metadata ["vector_store" ][
314+ "index_name"
315+ ]
311316 except KeyError :
312- index_name = client .get_secret (SECRET_NAME_PINECONE ).secret_values .get ("pinecone_index" , "zenml-docs-dev" )
317+ index_name = client .get_secret (
318+ SECRET_NAME_PINECONE
319+ ).secret_values .get ("pinecone_index" , "zenml-docs-dev" )
313320 # if index by that name exists already, create a new one with a random suffix
314321 if index_name in pc .list_indexes ().names ():
315322 index_name = f"{ index_name } -{ uuid .uuid4 ()} "
316- model_version .run_metadata ["vector_store" ]["index_name" ] = index_name
323+ model_version .run_metadata ["vector_store" ]["index_name" ] = (
324+ index_name
325+ )
317326
318327 # Create index if it doesn't exist
319328 if index_name not in pc .list_indexes ().names ():
320329 pc .create_index (
321330 name = index_name ,
322331 dimension = EMBEDDING_DIMENSIONALITY ,
323332 metric = "cosine" ,
324- spec = ServerlessSpec (
325- cloud = "aws" ,
326- region = "us-east-1"
327- )
333+ spec = ServerlessSpec (cloud = "aws" , region = "us-east-1" ),
328334 )
329335
330336 if model_version_stage == "production" :
331337 try :
332- index_name = model_version .run_metadata ["vector_store" ]["index_name" ]
338+ index_name = model_version .run_metadata ["vector_store" ][
339+ "index_name"
340+ ]
333341 except KeyError :
334- raise ValueError ("The production model version should have an index name attached to it. None found." )
335-
342+ raise ValueError (
343+ "The production model version should have an index name attached to it. None found."
344+ )
345+
336346 # if index doesn't exist, raise error
337347 if index_name not in pc .list_indexes ().names ():
338- raise ValueError (f"The index { index_name } attached to the production model version does not exist. Please create it first." )
348+ raise ValueError (
349+ f"The index { index_name } attached to the production model version does not exist. Please create it first."
350+ )
339351
340352 return pc .Index (index_name )
341353
@@ -469,7 +481,7 @@ def get_topn_similar_docs_pinecone(
469481 # Convert numpy array to list if needed
470482 if isinstance (query_embedding , np .ndarray ):
471483 query_embedding = query_embedding .tolist ()
472-
484+
473485 # Query the index
474486 results = pinecone_index .query (
475487 vector = query_embedding , top_k = n , include_metadata = True
@@ -667,7 +679,9 @@ def process_input_with_retrieval(
667679 include_metadata = True ,
668680 )
669681 elif vector_store == "pinecone" :
670- pinecone_index = get_pinecone_client (model_version_stage = model_version_stage )
682+ pinecone_index = get_pinecone_client (
683+ model_version_stage = model_version_stage
684+ )
671685 similar_docs = get_topn_similar_docs (
672686 query_embedding = query_embedding ,
673687 pinecone_index = pinecone_index ,
0 commit comments