Skip to content

Commit 9ab649f

Browse files
exowandererrti
authored and committed
created a rag_pipeline in rag.py based on the usage in api.py; removed rag_pipeline from api.py; introduced rag_pipeline from rag.py into api.py
1 parent 0d12a3a commit 9ab649f

File tree

3 files changed

+65
-69
lines changed

3 files changed

+65
-69
lines changed

gswikichat/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
# from .logger import logger
21
from .api import *

gswikichat/api.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
from fastapi.staticfiles import StaticFiles
33
from fastapi import FastAPI
44

5-
from .rag import answer_builder
6-
from .llm_config import llm
7-
from .prompt import prompt_builders
8-
from .vector_store_interface import embedder, retriever, input_documents
5+
from .rag import rag_pipeline
96

107
from haystack import Document
118
from .logger import get_logger
@@ -40,51 +37,14 @@ async def api(query, top_k=3, lang='en'):
4037

4138
logger.debug(f'{query=}') # Assuming we change the input name
4239
logger.debug(f'{top_k=}')
43-
logger.debug(f'{top_k=}')
44-
45-
query = Document(content=query)
40+
logger.debug(f'{lang=}')
4641

47-
query_embedded = embedder.run([query])
48-
query_embedding = query_embedded['documents'][0].embedding
49-
50-
retriever_results = retriever.run(
51-
query_embedding=list(query_embedding),
52-
filters=None,
42+
answer = rag_pipeline(
43+
query=query,
5344
top_k=top_k,
54-
scale_score=None,
55-
return_embedding=None
56-
)
57-
58-
logger.debug('retriever results:')
59-
for retriever_result_ in retriever_results:
60-
logger.debug(retriever_result_)
61-
62-
prompt_builder = prompt_builders[lang]
63-
64-
prompt_build = prompt_builder.run(
65-
question=query.content, # As a Document instance, .content returns a string
66-
documents=retriever_results['documents']
45+
lang=lang
6746
)
6847

69-
prompt = prompt_build['prompt']
70-
71-
logger.debug(f'{prompt=}')
72-
73-
response = llm.run(prompt=prompt, generation_kwargs=None)
74-
75-
answer_build = answer_builder.run(
76-
query=query.content, # As a Document class, .content returns the string
77-
replies=response['replies'],
78-
meta=response['meta'],
79-
documents=retriever_results['documents'],
80-
pattern=None,
81-
reference_pattern=None
82-
)
83-
84-
logger.debug(f'{answer_build=}')
85-
86-
answer = answer_build['answers'][0]
87-
8848
sources = [
8949
{
9050
"src": d_.meta['src'],

gswikichat/rag.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,63 @@
11
# from haystack import Pipeline
2+
from haystack import Document
23
from haystack.components.builders.answer_builder import AnswerBuilder
34

4-
answer_builder = AnswerBuilder()
5-
6-
# rag_pipeline = Pipeline()
7-
# rag_pipeline.add_component("text_embedder", embedder)
8-
# rag_pipeline.add_component("retriever", retriever)
9-
# # rag_pipeline.add_component("writer", writer)
10-
# rag_pipeline.add_component("prompt_builder", prompt_builder)
11-
# rag_pipeline.add_component("llm", llm)
12-
# rag_pipeline.add_component("answer_builder", answer_builder)
13-
14-
# # rag_pipeline.connect("embedder", "writer")
15-
# rag_pipeline.connect("retriever.documents", "text_embedder")
16-
# rag_pipeline.connect("retriever", "prompt_builder.documents")
17-
# rag_pipeline.connect("prompt_builder", "llm")
18-
# rag_pipeline.connect("llm.replies", "answer_builder.replies")
19-
# rag_pipeline.connect("llm.metadata", "answer_builder.meta")
20-
# rag_pipeline.connect("retriever", "answer_builder.documents")
21-
22-
# rag_pipeline.run(
23-
# {
24-
# "text_embedder": {"documents": input_documents}
25-
# }
26-
# )
5+
from .llm_config import llm
6+
from .logger import get_logger
7+
from .prompt import prompt_builders
8+
from .vector_store_interface import embedder, retriever, input_documents
9+
10+
# Create logger instance from base logger config in `logger.py`
11+
logger = get_logger(__name__)
12+
13+
def rag_pipeline(query: str = None, top_k: int = 3, lang: str = 'de'):
    """Run the full retrieval-augmented generation pipeline for one query.

    Embeds the query, retrieves the ``top_k`` most similar documents from the
    vector store, builds a language-specific prompt, generates a reply with
    the LLM, and wraps everything in a Haystack answer object.

    Parameters
    ----------
    query : str or Document
        The user question. A plain string is wrapped in a ``Document``.
    top_k : int, optional
        Number of documents to retrieve (default 3).
    lang : str, optional
        Key selecting the prompt builder from ``prompt_builders``
        (default ``'de'``).

    Returns
    -------
    The first answer produced by ``AnswerBuilder.run``.

    Raises
    ------
    ValueError
        If ``query`` is None.
    TypeError
        If ``query`` is neither a ``str`` nor a ``Document``.
    """
    # Raise explicitly instead of using `assert`: asserts are stripped
    # when Python runs with optimizations (-O), silently disabling the check.
    if query is None:
        raise ValueError("query must not be None")

    if isinstance(query, str):
        query = Document(content=query)

    if not isinstance(query, Document):
        raise TypeError(
            f"query must be a str or Document, got {type(query).__name__}"
        )

    # Embed the query so it can be compared against the stored vectors.
    query_embedded = embedder.run([query])
    query_embedding = query_embedded['documents'][0].embedding

    retriever_results = retriever.run(
        query_embedding=list(query_embedding),
        filters=None,
        top_k=top_k,
        scale_score=None,
        return_embedding=None
    )

    # Fix: iterate the retrieved documents. The original looped over the
    # result dict itself, which only logs the dict's keys, not the results.
    logger.debug('retriever results:')
    for retrieved_document in retriever_results['documents']:
        logger.debug(retrieved_document)

    # Select the prompt template for the requested language.
    prompt_builder = prompt_builders[lang]

    prompt_build = prompt_builder.run(
        question=query.content,  # As a Document instance, .content returns a string
        documents=retriever_results['documents']
    )

    prompt = prompt_build['prompt']

    logger.debug(f'{prompt=}')

    response = llm.run(prompt=prompt, generation_kwargs=None)

    # Wrap the raw LLM reply together with the supporting documents so the
    # caller gets a single structured answer object.
    answer_builder = AnswerBuilder()
    answer_build = answer_builder.run(
        query=query.content,  # As a Document class, .content returns the string
        replies=response['replies'],
        meta=response['meta'],
        documents=retriever_results['documents'],
        pattern=None,
        reference_pattern=None
    )

    logger.debug(f'{answer_build=}')

    return answer_build['answers'][0]

0 commit comments

Comments
 (0)