Skip to content

Commit 3edd521

Browse files
committed
merge
1 parent 2c224da commit 3edd521

File tree

6 files changed

+37
-45
lines changed

6 files changed

+37
-45
lines changed

llm-complete-guide/pipelines/llm_basic_rag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from zenml import pipeline
2626

2727

28-
@pipeline
28+
@pipeline(enable_cache=True)
2929
def llm_basic_rag() -> None:
3030
"""Executes the pipeline to train a basic RAG model.
3131

llm-complete-guide/pipelines/llm_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from zenml import pipeline
2929

3030

31-
@pipeline(enable_cache=False)
31+
@pipeline(enable_cache=True)
3232
def llm_eval(after: Optional[str] = None) -> None:
3333
"""Executes the pipeline to evaluate a RAG pipeline."""
3434
# Retrieval evals

llm-complete-guide/steps/eval_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def query_similar_docs(
9797
conn = get_db_conn()
9898
elif vector_store_name == "pinecone":
9999
# in pipeline runs, always use staging index
100-
pinecone_index = get_pinecone_client(model_version_stage="staging")
100+
pinecone_index = get_pinecone_client(model_version_name_or_id="staging")
101101
else:
102102
es_client = get_es_client()
103103

llm-complete-guide/steps/populate_index.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -648,13 +648,15 @@ def index_generator(
648648
"""
649649
# get model version
650650
context = get_step_context()
651-
model_version_stage = context.model_version.stage
651+
model_version_name_or_id = context.model_version.name
652+
if context.model_version.stage == "production":
653+
model_version_name_or_id = "production"
652654
if index_type == IndexType.ELASTICSEARCH:
653655
_index_generator_elastic(documents)
654656
elif index_type == IndexType.POSTGRES:
655657
_index_generator_postgres(documents)
656658
elif index_type == IndexType.PINECONE:
657-
_index_generator_pinecone(documents, model_version_stage)
659+
_index_generator_pinecone(documents, model_version_name_or_id)
658660
else:
659661
raise ValueError(f"Unknown index type: {index_type}")
660662

@@ -829,16 +831,14 @@ def _index_generator_postgres(documents: str) -> None:
829831
conn.close()
830832

831833

832-
def _index_generator_pinecone(
833-
documents: str, model_version_stage: str
834-
) -> None:
834+
def _index_generator_pinecone(documents: str, model_version_name_or_id: str) -> None:
835835
"""Generates a Pinecone index for the given documents.
836836
837837
Args:
838838
documents (str): JSON string containing the documents to index.
839839
model_version (str): Name of the model version.
840840
"""
841-
index = get_pinecone_client(model_version_stage=model_version_stage)
841+
index = get_pinecone_client(model_version_name_or_id=model_version_name_or_id)
842842

843843
# Load documents
844844
docs = json.loads(documents)

llm-complete-guide/steps/url_scraper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ def url_scraper(
5353
else:
5454
docs_urls = get_all_pages(docs_url)
5555

56-
# website_urls = get_all_pages(website_url)
56+
website_urls = get_all_pages(website_url)
5757
# all_urls = docs_urls + website_urls + examples_readme_urls
58-
all_urls = docs_urls
58+
# all_urls = website_urls
59+
all_urls = ["https://zenml.io"]
5960
log_metadata(
6061
metadata={
6162
"count": len(all_urls),

llm-complete-guide/utils/llm_utils.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,7 @@ def get_db_conn() -> connection:
285285
raise
286286

287287

288-
def get_pinecone_client(
289-
model_version_stage: str = "staging",
290-
) -> pinecone.Index:
288+
def get_pinecone_client(model_version_name_or_id: str = "dev") -> pinecone.Index:
291289
"""Get a Pinecone index client.
292290
293291
Returns:
@@ -305,24 +303,35 @@ def get_pinecone_client(
305303
# raise error if there is no index name attached to the metadata
306304
model_version = client.get_model_version(
307305
model_name_or_id=ZENML_CHATBOT_MODEL_NAME,
308-
model_version_name_or_number_or_id=model_version_stage,
306+
model_version_name_or_number_or_id=model_version_name_or_id,
309307
)
310308

311-
if model_version_stage == "staging":
309+
index_name_from_secret = client.get_secret(SECRET_NAME_PINECONE).secret_values.get("pinecone_index", "zenml-docs")
310+
311+
if model_version_name_or_id == "production":
312+
index_name = f"{index_name_from_secret}-prod"
313+
314+
model_version.run_metadata["vector_store"]["index_name"] = index_name
315+
316+
# delete index if it exists
317+
if index_name in pc.list_indexes().names():
318+
pc.delete_index(index_name)
319+
320+
# create index
321+
pc.create_index(
322+
name=index_name,
323+
dimension=EMBEDDING_DIMENSIONALITY,
324+
metric="cosine",
325+
spec=ServerlessSpec(cloud="aws", region="us-east-1")
326+
)
327+
else:
312328
try:
313329
index_name = model_version.run_metadata["vector_store"][
314330
"index_name"
315331
]
316332
except KeyError:
317-
index_name = client.get_secret(
318-
SECRET_NAME_PINECONE
319-
).secret_values.get("pinecone_index", "zenml-docs-dev")
320-
# if index by that name exists already, create a new one with a random suffix
321-
if index_name in pc.list_indexes().names():
322-
index_name = f"{index_name}-{uuid.uuid4()}"
323-
model_version.run_metadata["vector_store"]["index_name"] = (
324-
index_name
325-
)
333+
index_name = index_name_from_secret
334+
model_version.run_metadata["vector_store"]["index_name"] = index_name
326335

327336
# Create index if it doesn't exist
328337
if index_name not in pc.list_indexes().names():
@@ -332,23 +341,7 @@ def get_pinecone_client(
332341
metric="cosine",
333342
spec=ServerlessSpec(cloud="aws", region="us-east-1"),
334343
)
335-
336-
if model_version_stage == "production":
337-
try:
338-
index_name = model_version.run_metadata["vector_store"][
339-
"index_name"
340-
]
341-
except KeyError:
342-
raise ValueError(
343-
"The production model version should have an index name attached to it. None found."
344-
)
345-
346-
# if index doesn't exist, raise error
347-
if index_name not in pc.list_indexes().names():
348-
raise ValueError(
349-
f"The index {index_name} attached to the production model version does not exist. Please create it first."
350-
)
351-
344+
352345
return pc.Index(index_name)
353346

354347

@@ -679,9 +672,7 @@ def process_input_with_retrieval(
679672
include_metadata=True,
680673
)
681674
elif vector_store == "pinecone":
682-
pinecone_index = get_pinecone_client(
683-
model_version_stage=model_version_stage
684-
)
675+
pinecone_index = get_pinecone_client(model_version_name_or_id=model_version_stage)
685676
similar_docs = get_topn_similar_docs(
686677
query_embedding=query_embedding,
687678
pinecone_index=pinecone_index,

0 commit comments

Comments
 (0)