Skip to content

Commit 7d34356

Browse files
committed
Formatting
1 parent 84a5c29 commit 7d34356

File tree

3 files changed

+47
-24
lines changed

3 files changed

+47
-24
lines changed

llm-complete-guide/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def main(
236236
response = process_input_with_retrieval(
237237
query_text,
238238
model=model,
239-
use_reranking=use_reranker, model_version_stage="production",
239+
use_reranking=use_reranker,
240+
model_version_stage="production",
240241
tracing_tags=["cli", "dev"],
241242
)
242243
console = Console()

llm-complete-guide/steps/populate_index.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,14 @@
4646
)
4747
from pgvector.psycopg2 import register_vector
4848
from PIL import Image, ImageDraw, ImageFont
49-
from pinecone import Pinecone, ServerlessSpec
5049
from sentence_transformers import SentenceTransformer
5150
from structures import Document
52-
from utils.llm_utils import get_db_conn, get_es_client, get_pinecone_client, split_documents
51+
from utils.llm_utils import (
52+
get_db_conn,
53+
get_es_client,
54+
get_pinecone_client,
55+
split_documents,
56+
)
5357
from zenml import ArtifactConfig, get_step_context, log_metadata, step
5458
from zenml.client import Client
5559
from zenml.metadata.metadata_types import Uri
@@ -642,7 +646,7 @@ def index_generator(
642646
documents (str): JSON string containing the documents to index.
643647
index_type (IndexType, optional): Type of index to generate. Defaults to IndexType.POSTGRES.
644648
"""
645-
# get model version
649+
# get model version
646650
context = get_step_context()
647651
model_version_stage = context.model_version.stage
648652
if index_type == IndexType.ELASTICSEARCH:
@@ -825,7 +829,9 @@ def _index_generator_postgres(documents: str) -> None:
825829
conn.close()
826830

827831

828-
def _index_generator_pinecone(documents: str, model_version_stage: str) -> None:
832+
def _index_generator_pinecone(
833+
documents: str, model_version_stage: str
834+
) -> None:
829835
"""Generates a Pinecone index for the given documents.
830836
831837
Args:
@@ -870,7 +876,9 @@ def _index_generator_pinecone(documents: str, model_version_stage: str) -> None:
870876
if batch:
871877
index.upsert(vectors=batch)
872878

873-
logger.info(f"Successfully indexed {len(docs)} documents to Pinecone index")
879+
logger.info(
880+
f"Successfully indexed {len(docs)} documents to Pinecone index"
881+
)
874882

875883

876884
def _log_metadata(index_type: IndexType) -> None:

llm-complete-guide/utils/llm_utils.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@
2525

2626
import pinecone
2727
from elasticsearch import Elasticsearch
28-
from pinecone import Pinecone
28+
from pinecone import Pinecone, ServerlessSpec
2929
from zenml.client import Client
3030

3131
from utils.openai_utils import get_openai_api_key
32-
import pinecone
33-
from pinecone import Pinecone, ServerlessSpec
32+
3433
# Configure logging levels for specific modules
3534
logging.getLogger("pytorch").setLevel(logging.CRITICAL)
3635
logging.getLogger("sentence-transformers").setLevel(logging.CRITICAL)
@@ -286,14 +285,18 @@ def get_db_conn() -> connection:
286285
raise
287286

288287

289-
def get_pinecone_client(model_version_stage: str = "staging") -> pinecone.Index:
288+
def get_pinecone_client(
289+
model_version_stage: str = "staging",
290+
) -> pinecone.Index:
290291
"""Get a Pinecone index client.
291292
292293
Returns:
293294
pinecone.Index: A Pinecone index client.
294295
"""
295296
client = Client()
296-
pinecone_api_key = client.get_secret(SECRET_NAME_PINECONE).secret_values["pinecone_api_key"]
297+
pinecone_api_key = client.get_secret(SECRET_NAME_PINECONE).secret_values[
298+
"pinecone_api_key"
299+
]
297300
pc = Pinecone(api_key=pinecone_api_key)
298301

299302
# if the model version is staging, we check if any index name is associated as metadata
@@ -307,35 +310,44 @@ def get_pinecone_client(model_version_stage: str = "staging") -> pinecone.Index:
307310

308311
if model_version_stage == "staging":
309312
try:
310-
index_name = model_version.run_metadata["vector_store"]["index_name"]
313+
index_name = model_version.run_metadata["vector_store"][
314+
"index_name"
315+
]
311316
except KeyError:
312-
index_name = client.get_secret(SECRET_NAME_PINECONE).secret_values.get("pinecone_index", "zenml-docs-dev")
317+
index_name = client.get_secret(
318+
SECRET_NAME_PINECONE
319+
).secret_values.get("pinecone_index", "zenml-docs-dev")
313320
# if index by that name exists already, create a new one with a random suffix
314321
if index_name in pc.list_indexes().names():
315322
index_name = f"{index_name}-{uuid.uuid4()}"
316-
model_version.run_metadata["vector_store"]["index_name"] = index_name
323+
model_version.run_metadata["vector_store"]["index_name"] = (
324+
index_name
325+
)
317326

318327
# Create index if it doesn't exist
319328
if index_name not in pc.list_indexes().names():
320329
pc.create_index(
321330
name=index_name,
322331
dimension=EMBEDDING_DIMENSIONALITY,
323332
metric="cosine",
324-
spec=ServerlessSpec(
325-
cloud="aws",
326-
region="us-east-1"
327-
)
333+
spec=ServerlessSpec(cloud="aws", region="us-east-1"),
328334
)
329335

330336
if model_version_stage == "production":
331337
try:
332-
index_name = model_version.run_metadata["vector_store"]["index_name"]
338+
index_name = model_version.run_metadata["vector_store"][
339+
"index_name"
340+
]
333341
except KeyError:
334-
raise ValueError("The production model version should have an index name attached to it. None found.")
335-
342+
raise ValueError(
343+
"The production model version should have an index name attached to it. None found."
344+
)
345+
336346
# if index doesn't exist, raise error
337347
if index_name not in pc.list_indexes().names():
338-
raise ValueError(f"The index {index_name} attached to the production model version does not exist. Please create it first.")
348+
raise ValueError(
349+
f"The index {index_name} attached to the production model version does not exist. Please create it first."
350+
)
339351

340352
return pc.Index(index_name)
341353

@@ -469,7 +481,7 @@ def get_topn_similar_docs_pinecone(
469481
# Convert numpy array to list if needed
470482
if isinstance(query_embedding, np.ndarray):
471483
query_embedding = query_embedding.tolist()
472-
484+
473485
# Query the index
474486
results = pinecone_index.query(
475487
vector=query_embedding, top_k=n, include_metadata=True
@@ -667,7 +679,9 @@ def process_input_with_retrieval(
667679
include_metadata=True,
668680
)
669681
elif vector_store == "pinecone":
670-
pinecone_index = get_pinecone_client(model_version_stage=model_version_stage)
682+
pinecone_index = get_pinecone_client(
683+
model_version_stage=model_version_stage
684+
)
671685
similar_docs = get_topn_similar_docs(
672686
query_embedding=query_embedding,
673687
pinecone_index=pinecone_index,

0 commit comments

Comments
 (0)