Commit 6fc5b8a

Reformatted and cleaned up deprecated code
1 parent a5c8be6 commit 6fc5b8a

24 files changed: +321 −275 lines

llm-complete-guide/gh_action_rag.py

Lines changed: 6 additions & 8 deletions
@@ -21,12 +21,10 @@

 import click
 import yaml
-from zenml.enums import PluginSubType
-
 from pipelines.llm_index_and_evaluate import llm_index_and_evaluate
-from zenml.client import Client
 from zenml import Model
-from zenml.exceptions import ZenKeyError
+from zenml.client import Client
+from zenml.enums import PluginSubType


 @click.command(
@@ -89,7 +87,7 @@ def main(
     zenml_model_name: Optional[str] = "zenml-docs-qa-rag",
     zenml_model_version: Optional[str] = None,
 ):
-    """
+    """
     Executes the pipeline to train a basic RAG model.

     Args:
@@ -108,14 +106,14 @@ def main(
         config = yaml.safe_load(file)

     # Read the model version from a file in the root of the repo
-    # called "ZENML_VERSION.txt".
+    # called "ZENML_VERSION.txt".
     if zenml_model_version == "staging":
         postfix = "-rc0"
     elif zenml_model_version == "production":
         postfix = ""
     else:
         postfix = "-dev"
-
+
     if Path("ZENML_VERSION.txt").exists():
         with open("ZENML_VERSION.txt", "r") as file:
             zenml_model_version = file.read().strip()
@@ -177,7 +175,7 @@ def main(
             service_account_id=service_account_id,
             auth_window=0,
             flavor="builtin",
-            action_type=PluginSubType.PIPELINE_RUN
+            action_type=PluginSubType.PIPELINE_RUN,
         ).id
         client.create_trigger(
             name="Production Trigger LLM-Complete",

llm-complete-guide/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from pipelines.generate_chunk_questions import generate_chunk_questions
 from pipelines.llm_basic_rag import llm_basic_rag
 from pipelines.llm_eval import llm_eval
-from pipelines.rag_deployment import rag_deployment
 from pipelines.llm_index_and_evaluate import llm_index_and_evaluate
 from pipelines.local_deployment import local_deployment
-from pipelines.prod_deployment import production_deployment
+from pipelines.prod_deployment import production_deployment
+from pipelines.rag_deployment import rag_deployment

llm-complete-guide/pipelines/finetune_embeddings.py

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.

-from constants import EMBEDDINGS_MODEL_NAME_ZENML
 from steps.finetune_embeddings import (
     evaluate_base_model,
     evaluate_finetuned_model,

llm-complete-guide/pipelines/llm_basic_rag.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from litellm import config_path

 from steps.populate_index import (
     generate_embeddings,

llm-complete-guide/pipelines/llm_index_and_evaluate.py

Lines changed: 2 additions & 1 deletion
@@ -15,9 +15,10 @@
 # limitations under the License.
 #

-from pipelines import llm_basic_rag, llm_eval
 from zenml import pipeline

+from pipelines import llm_basic_rag, llm_eval
+

 @pipeline
 def llm_index_and_evaluate() -> None:

llm-complete-guide/pipelines/local_deployment.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 from steps.bento_builder import bento_builder
 from steps.bento_deployment import bento_deployment
-from steps.visualize_chat import create_chat_interface
 from zenml import pipeline



llm-complete-guide/pipelines/prod_deployment.py

Lines changed: 2 additions & 3 deletions
@@ -22,12 +22,11 @@


 @pipeline(enable_cache=False)
-def production_deployment(
-):
+def production_deployment():
     """Model deployment pipeline.

     This is a pipeline deploys trained model for future inference.
     """
     bento_model_image = bento_dockerizer()
     deployment_info = k8s_deployment(bento_model_image)
-    create_chat_interface(deployment_info)
+    create_chat_interface(deployment_info)
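For readers unfamiliar with ZenML, a minimal sketch of the pattern production_deployment follows: steps composed inside a @pipeline function, with caching disabled so every run redeploys. The step names and bodies below are illustrative stand-ins, not the repo's actual bento_dockerizer/k8s_deployment implementations:

from zenml import pipeline, step


@step
def build_image() -> str:
    # Stand-in for bento_dockerizer(): build and tag a container image.
    return "registry.example.com/rag-service:latest"  # hypothetical tag


@step
def deploy(image: str) -> str:
    # Stand-in for k8s_deployment(): roll the image out, report the endpoint.
    return f"deployed {image} at https://rag.example.com"  # hypothetical endpoint


@pipeline(enable_cache=False)  # never reuse cached step outputs for deployments
def demo_deployment():
    deploy(build_image())


if __name__ == "__main__":
    demo_deployment()  # needs an initialized ZenML stack to actually run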

llm-complete-guide/run.py

Lines changed: 11 additions & 9 deletions
@@ -47,14 +47,14 @@
     generate_synthetic_data,
     llm_basic_rag,
     llm_eval,
-    rag_deployment,
     llm_index_and_evaluate,
     local_deployment,
     production_deployment,
+    rag_deployment,
 )
 from structures import Document
-from zenml.materializers.materializer_registry import materializer_registry
 from zenml import Model
+from zenml.materializers.materializer_registry import materializer_registry

 logger = get_logger(__name__)

@@ -150,7 +150,7 @@
     "env",
     default="local",
     help="The environment to use for the completion.",
-    )
+)
 def main(
     pipeline: str,
     query_text: Optional[str] = None,
@@ -186,9 +186,9 @@ def main(
             }
         },
     }
-
+
     # Read the model version from a file in the root of the repo
-    # called "ZENML_VERSION.txt".
+    # called "ZENML_VERSION.txt".
     if zenml_model_version == "staging":
         postfix = "-rc0"
     elif zenml_model_version == "production":
@@ -200,8 +200,8 @@ def main(
         with open("ZENML_VERSION.txt", "r") as file:
             zenml_version = file.read().strip()
             zenml_version += postfix
-            #zenml_model_version = file.read().strip()
-            #zenml_model_version += postfix
+            # zenml_model_version = file.read().strip()
+            # zenml_model_version += postfix
     else:
         raise RuntimeError(
             "No model version file found. Please create a file called ZENML_VERSION.txt in the root of the repo with the model version."
@@ -294,7 +294,9 @@ def main(

     elif pipeline == "embeddings":
         finetune_embeddings.with_options(
-            model=zenml_model, config_path=config_path, **embeddings_finetune_args
+            model=zenml_model,
+            config_path=config_path,
+            **embeddings_finetune_args,
         )()

     elif pipeline == "chunks":
@@ -309,4 +311,4 @@ def main(
     materializer_registry.register_materializer_type(
         Document, DocumentMaterializer
     )
-    main()
+    main()
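The final hunk registers DocumentMaterializer for the custom Document type so ZenML knows how to persist it between steps. A minimal sketch of what such a pairing can look like, assuming a recent ZenML release where materializers implement load()/save() and write under self.uri; the Document fields and JSON layout are invented stand-ins for the repo's structures.Document and DocumentMaterializer:

import json
import os
from typing import Type

from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.materializer_registry import materializer_registry


class Document:
    """Simplified stand-in for structures.Document."""

    def __init__(self, text: str):
        self.text = text


class DocumentMaterializer(BaseMaterializer):
    ASSOCIATED_TYPES = (Document,)

    def load(self, data_type: Type[Document]) -> Document:
        # A real materializer should use zenml.io.fileio for remote artifact stores.
        with open(os.path.join(self.uri, "document.json")) as f:
            return Document(**json.load(f))

    def save(self, data: Document) -> None:
        with open(os.path.join(self.uri, "document.json"), "w") as f:
            json.dump({"text": data.text}, f)


materializer_registry.register_materializer_type(Document, DocumentMaterializer)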

llm-complete-guide/service.py

Lines changed: 64 additions & 40 deletions
@@ -1,14 +1,11 @@
-import asyncio
-from typing import Any, AsyncGenerator, Dict
+from typing import AsyncGenerator

 import bentoml
 import litellm
 import numpy as np
 from constants import (
-    EMBEDDINGS_MODEL_ID_FINE_TUNED,
     MODEL_NAME_MAP,
     OPENAI_MODEL,
-    SECRET_NAME,
     SECRET_NAME_ELASTICSEARCH,
 )
 from elasticsearch import Elasticsearch
@@ -29,30 +26,43 @@
     http={
         "cors": {
             "enabled": True,
-            "access_control_allow_origins": ["https://cloud.zenml.io"],  # Add your allowed origins
-            "access_control_allow_methods": ["GET", "OPTIONS", "POST", "HEAD", "PUT"],
+            "access_control_allow_origins": [
+                "https://cloud.zenml.io"
+            ],  # Add your allowed origins
+            "access_control_allow_methods": [
+                "GET",
+                "OPTIONS",
+                "POST",
+                "HEAD",
+                "PUT",
+            ],
             "access_control_allow_credentials": True,
             "access_control_allow_headers": ["*"],
             # "access_control_allow_origin_regex": "https://.*\.my_org\.com",  # Optional regex
             "access_control_max_age": 1200,
             "access_control_expose_headers": ["Content-Length"],
         }
-    }
+    },
 )
 class RAGService:
     """RAG service for generating responses using LLM and RAG."""
+
     def __init__(self):
         """Initialize the RAG service."""
         # Initialize embeddings model
         self.embeddings_model = SentenceTransformer(EMBEDDINGS_MODEL)
-
+
         # Initialize reranker
         self.reranker = Reranker("flashrank")
-
+
         # Initialize Elasticsearch client
         client = Client()
-        es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"]
-        es_api_key = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_api_key"]
+        es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values[
+            "elasticsearch_host"
+        ]
+        es_api_key = client.get_secret(
+            SECRET_NAME_ELASTICSEARCH
+        ).secret_values["elasticsearch_api_key"]
         self.es_client = Elasticsearch(es_host, api_key=es_api_key)

     def get_embeddings(self, text: str) -> np.ndarray:
@@ -62,40 +72,52 @@ def get_embeddings(self, text: str) -> np.ndarray:
             embeddings = embeddings[0]
         return embeddings

-    def get_similar_docs(self, query_embedding: np.ndarray, n: int = 20) -> list:
+    def get_similar_docs(
+        self, query_embedding: np.ndarray, n: int = 20
+    ) -> list:
         """Get similar documents for the given query embedding."""
         if query_embedding.ndim == 2:
             query_embedding = query_embedding[0]
-
-        response = self.es_client.search(index="zenml_docs", knn={
-            "field": "embedding",
-            "query_vector": query_embedding.tolist(),
-            "num_candidates": 50,
-            "k": n
-        })
-
+
+        response = self.es_client.search(
+            index="zenml_docs",
+            knn={
+                "field": "embedding",
+                "query_vector": query_embedding.tolist(),
+                "num_candidates": 50,
+                "k": n,
+            },
+        )
+
         docs = []
         for hit in response["hits"]["hits"]:
-            docs.append({
-                "content": hit["_source"]["content"],
-                "url": hit["_source"]["url"],
-                "parent_section": hit["_source"]["parent_section"]
-            })
+            docs.append(
+                {
+                    "content": hit["_source"]["content"],
+                    "url": hit["_source"]["url"],
+                    "parent_section": hit["_source"]["parent_section"],
+                }
+            )
         return docs

     def rerank_documents(self, query: str, documents: list) -> list:
         """Rerank documents using the reranker."""
-        docs_texts = [f"{doc['content']} PARENT SECTION: {doc['parent_section']}" for doc in documents]
+        docs_texts = [
+            f"{doc['content']} PARENT SECTION: {doc['parent_section']}"
+            for doc in documents
+        ]
         results = self.reranker.rank(query=query, docs=docs_texts)
-
+
         reranked_docs = []
         for result in results.results:
             index_val = result.doc_id
             doc = documents[index_val]
             reranked_docs.append((result.text, doc["url"]))
         return reranked_docs[:5]

-    async def get_completion(self, messages: list, model: str, temperature: float, max_tokens: int) -> AsyncGenerator[str, None]:
+    async def get_completion(
+        self, messages: list, model: str, temperature: float, max_tokens: int
+    ) -> AsyncGenerator[str, None]:
         """Handle the completion request and streaming response."""
         try:
             response = await litellm.acompletion(
@@ -104,9 +126,9 @@ async def get_completion(self, messages: list, model: str, temperature: float, m
                 temperature=temperature,
                 max_tokens=max_tokens,
                 api_key=get_openai_api_key(),
-                stream=True
+                stream=True,
             )
-
+
             async for chunk in response:
                 if chunk.choices and chunk.choices[0].delta.content:
                     yield chunk.choices[0].delta.content
@@ -124,16 +146,16 @@ async def generate(
         try:
             # Get embeddings for query
             query_embedding = self.get_embeddings(query)
-
+
             # Retrieve similar documents
             similar_docs = self.get_similar_docs(query_embedding, n=20)
-
+
             # Rerank documents
             reranked_docs = self.rerank_documents(query, similar_docs)
-
+
             # Prepare context from reranked documents
             context = "\n\n".join([doc[0] for doc in reranked_docs])
-
+
             # Prepare system message
             system_message = """
             You are a friendly chatbot. \
@@ -149,15 +171,17 @@ async def generate(
                 {"role": "system", "content": system_message},
                 {"role": "user", "content": query},
                 {
-                    "role": "assistant",
-                    "content": f"Please use the following relevant ZenML documentation to answer the query: \n{context}"
-                }
+                    "role": "assistant",
+                    "content": f"Please use the following relevant ZenML documentation to answer the query: \n{context}",
+                },
             ]

             # Get completion from LLM using the new async method
             model = MODEL_NAME_MAP.get(OPENAI_MODEL, OPENAI_MODEL)
-            async for chunk in self.get_completion(messages, model, temperature, max_tokens):
+            async for chunk in self.get_completion(
+                messages, model, temperature, max_tokens
+            ):
                 yield chunk
-
+
         except Exception as e:
-            yield f"Error occurred: {str(e)}"
+            yield f"Error occurred: {str(e)}"
