@@ -1,14 +1,11 @@
-import asyncio
-from typing import Any, AsyncGenerator, Dict
+from typing import AsyncGenerator
 
 import bentoml
 import litellm
 import numpy as np
 from constants import (
-    EMBEDDINGS_MODEL_ID_FINE_TUNED,
     MODEL_NAME_MAP,
     OPENAI_MODEL,
-    SECRET_NAME,
     SECRET_NAME_ELASTICSEARCH,
 )
 from elasticsearch import Elasticsearch
@@ -29,30 +26,43 @@
     http={
         "cors": {
             "enabled": True,
-            "access_control_allow_origins": ["https://cloud.zenml.io"],  # Add your allowed origins
-            "access_control_allow_methods": ["GET", "OPTIONS", "POST", "HEAD", "PUT"],
+            "access_control_allow_origins": [
+                "https://cloud.zenml.io"
+            ],  # Add your allowed origins
+            "access_control_allow_methods": [
+                "GET",
+                "OPTIONS",
+                "POST",
+                "HEAD",
+                "PUT",
+            ],
             "access_control_allow_credentials": True,
             "access_control_allow_headers": ["*"],
             # "access_control_allow_origin_regex": "https://.*\.my_org\.com",  # Optional regex
             "access_control_max_age": 1200,
             "access_control_expose_headers": ["Content-Length"],
         }
-    }
+    },
 )
 class RAGService:
     """RAG service for generating responses using LLM and RAG."""
+
     def __init__(self):
         """Initialize the RAG service."""
         # Initialize embeddings model
         self.embeddings_model = SentenceTransformer(EMBEDDINGS_MODEL)
-
+
         # Initialize reranker
         self.reranker = Reranker("flashrank")
-
+
         # Initialize Elasticsearch client
         client = Client()
-        es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"]
-        es_api_key = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_api_key"]
+        es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values[
+            "elasticsearch_host"
+        ]
+        es_api_key = client.get_secret(
+            SECRET_NAME_ELASTICSEARCH
+        ).secret_values["elasticsearch_api_key"]
         self.es_client = Elasticsearch(es_host, api_key=es_api_key)
 
     def get_embeddings(self, text: str) -> np.ndarray:
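
The credentials fetched in __init__ above are read from a ZenML secret with the keys elasticsearch_host and elasticsearch_api_key. For reference, a minimal sketch of registering that secret through ZenML's Python client — the secret name "elasticsearch_secret" and both values are placeholders; the real name comes from constants.SECRET_NAME_ELASTICSEARCH, which this diff doesn't show:

    from zenml.client import Client

    # Hypothetical name and placeholder values; only the two keys are taken from the diff.
    Client().create_secret(
        name="elasticsearch_secret",
        values={
            "elasticsearch_host": "https://my-es-host:9200",
            "elasticsearch_api_key": "<your-api-key>",
        },
    )
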
@@ -62,40 +72,52 @@ def get_embeddings(self, text: str) -> np.ndarray:
         embeddings = embeddings[0]
         return embeddings
 
-    def get_similar_docs(self, query_embedding: np.ndarray, n: int = 20) -> list:
+    def get_similar_docs(
+        self, query_embedding: np.ndarray, n: int = 20
+    ) -> list:
         """Get similar documents for the given query embedding."""
         if query_embedding.ndim == 2:
             query_embedding = query_embedding[0]
-
-        response = self.es_client.search(index="zenml_docs", knn={
-            "field": "embedding",
-            "query_vector": query_embedding.tolist(),
-            "num_candidates": 50,
-            "k": n
-        })
-
+
+        response = self.es_client.search(
+            index="zenml_docs",
+            knn={
+                "field": "embedding",
+                "query_vector": query_embedding.tolist(),
+                "num_candidates": 50,
+                "k": n,
+            },
+        )
+
         docs = []
         for hit in response["hits"]["hits"]:
-            docs.append({
-                "content": hit["_source"]["content"],
-                "url": hit["_source"]["url"],
-                "parent_section": hit["_source"]["parent_section"]
-            })
+            docs.append(
+                {
+                    "content": hit["_source"]["content"],
+                    "url": hit["_source"]["url"],
+                    "parent_section": hit["_source"]["parent_section"],
+                }
+            )
         return docs
 
     def rerank_documents(self, query: str, documents: list) -> list:
         """Rerank documents using the reranker."""
-        docs_texts = [f"{doc['content']} PARENT SECTION: {doc['parent_section']}" for doc in documents]
+        docs_texts = [
+            f"{doc['content']} PARENT SECTION: {doc['parent_section']}"
+            for doc in documents
+        ]
         results = self.reranker.rank(query=query, docs=docs_texts)
-
+
         reranked_docs = []
         for result in results.results:
             index_val = result.doc_id
             doc = documents[index_val]
             reranked_docs.append((result.text, doc["url"]))
         return reranked_docs[:5]
 
-    async def get_completion(self, messages: list, model: str, temperature: float, max_tokens: int) -> AsyncGenerator[str, None]:
+    async def get_completion(
+        self, messages: list, model: str, temperature: float, max_tokens: int
+    ) -> AsyncGenerator[str, None]:
         """Handle the completion request and streaming response."""
         try:
             response = await litellm.acompletion(
@@ -104,9 +126,9 @@ async def get_completion(self, messages: list, model: str, temperature: float, m
                 temperature=temperature,
                 max_tokens=max_tokens,
                 api_key=get_openai_api_key(),
-                stream=True
+                stream=True,
             )
-
+
             async for chunk in response:
                 if chunk.choices and chunk.choices[0].delta.content:
                     yield chunk.choices[0].delta.content
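
Because get_completion is now a proper async generator, the streaming path can be smoke-tested outside BentoML by driving it with asyncio — a minimal sketch, assuming a constructed RAGService, a valid OpenAI key, and a placeholder model name:

    import asyncio

    async def main() -> None:
        service = RAGService()
        messages = [{"role": "user", "content": "What is ZenML?"}]
        # Consume the stream chunk by chunk as it arrives.
        async for chunk in service.get_completion(
            messages, model="gpt-4o-mini", temperature=0.4, max_tokens=256
        ):
            print(chunk, end="", flush=True)

    asyncio.run(main())
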
@@ -124,16 +146,16 @@ async def generate(
         try:
             # Get embeddings for query
             query_embedding = self.get_embeddings(query)
-
+
             # Retrieve similar documents
             similar_docs = self.get_similar_docs(query_embedding, n=20)
-
+
             # Rerank documents
             reranked_docs = self.rerank_documents(query, similar_docs)
-
+
             # Prepare context from reranked documents
             context = "\n\n".join([doc[0] for doc in reranked_docs])
-
+
             # Prepare system message
             system_message = """
             You are a friendly chatbot. \
@@ -149,15 +171,17 @@ async def generate(
149171 {"role" : "system" , "content" : system_message },
150172 {"role" : "user" , "content" : query },
151173 {
152- "role" : "assistant" ,
153- "content" : f"Please use the following relevant ZenML documentation to answer the query: \n { context } "
154- }
174+ "role" : "assistant" ,
175+ "content" : f"Please use the following relevant ZenML documentation to answer the query: \n { context } " ,
176+ },
155177 ]
156178
157179 # Get completion from LLM using the new async method
158180 model = MODEL_NAME_MAP .get (OPENAI_MODEL , OPENAI_MODEL )
159- async for chunk in self .get_completion (messages , model , temperature , max_tokens ):
181+ async for chunk in self .get_completion (
182+ messages , model , temperature , max_tokens
183+ ):
160184 yield chunk
161-
185+
162186 except Exception as e :
163- yield f"Error occurred: { str (e )} "
187+ yield f"Error occurred: { str (e )} "
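
Once the service is deployed, the streaming generate endpoint can be exercised over HTTP with BentoML's client — a sketch, assuming generate is exposed as an API on this service, that its parameters match the signature shown in this diff, and that the server listens on the default local port:

    import bentoml

    # Endpoint name and parameters mirror generate() as it appears in the diff above.
    with bentoml.SyncHTTPClient("http://localhost:3000") as client:
        for chunk in client.generate(
            query="How do I register a ZenML stack?",
            temperature=0.4,
            max_tokens=512,
        ):
            print(chunk, end="", flush=True)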