-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathknowledge_base_service.py
More file actions
315 lines (258 loc) · 10 KB
/
knowledge_base_service.py
File metadata and controls
315 lines (258 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import os
import uuid
from markitdown import MarkItDown
from typing import List, Dict, Any, Optional
from app.core.config import settings
from app.core.singleton import get_openai_client, get_qdrant_client
from langchain.text_splitter import RecursiveCharacterTextSplitter
from app.services.qdrant_service import (
create_qdrant_collection,
insert_qdrant_points,
search_qdrant_points,
delete_qdrant_documents,
)
from langfuse import observe
# Qdrant collection that backs the knowledge base, read from application settings.
KB_COLLECTION_NAME = settings.QDRANT_COLLECTION_NAME

# Shared clients, created once at import time via the helpers in app.core.singleton
# and used by every function in this module.
openai_client = get_openai_client()
qdrant_client = get_qdrant_client()

# Document-to-text converter. It is given the OpenAI client and the chat deployment
# name from settings so it can call the model where it needs one.
markitdown_client = MarkItDown(
    llm_client=openai_client,
    model=settings.AZURE_DEPLOYMENT_NAME,
)
def create_knowledge_base_collection_if_not_exists():
    """
    Make sure the knowledge base collection exists in Qdrant, creating it if missing.

    Returns:
        dict: {"status": "success", "message": ...} when the collection already
        existed (message is a text notice) or was just created (message is whatever
        create_qdrant_collection returned); {"status": "error", "message": ...} if
        any step raised.
    """
    try:
        existing_names = {
            collection.name for collection in qdrant_client.get_collections().collections
        }
        if KB_COLLECTION_NAME in existing_names:
            return {
                "status": "success",
                "message": f"Collection {KB_COLLECTION_NAME} already exists.",
            }
        # 1536-dimensional cosine vectors, sized for Azure OpenAI text-embedding-ada-002.
        creation_result = create_qdrant_collection(
            collection_name=KB_COLLECTION_NAME,
            vector_size=1536,
            distance="cosine",
        )
        return {"status": "success", "message": creation_result}
    except Exception as e:
        return {"status": "error", "message": f"Error creating collection: {str(e)}"}
def embed_text(text: str):
    """
    Return the embedding vector of a single text string.

    Calls the Azure OpenAI embeddings endpoint with the deployment named by
    settings.AZURE_EMBEDDING_DEPLOYMENT_NAME and returns the embedding of the
    first (and only) result item.

    Args:
        text (str): The text to embed.
    """
    embedding_deployment = settings.AZURE_EMBEDDING_DEPLOYMENT_NAME
    api_response = openai_client.embeddings.create(
        input=text, model=embedding_deployment
    )
    first_item = api_response.data[0]
    return first_item.embedding
def chunk_text(text: str, chunk_size: int = 1250, chunk_overlap: int = 250):
    """
    Split a text string into overlapping chunks with LangChain's
    RecursiveCharacterTextSplitter.

    Args:
        text (str): The text to chunk.
        chunk_size (int): Target maximum chunk length in characters. Defaults to 1250.
        chunk_overlap (int): Number of characters shared between consecutive chunks.
            Defaults to 250.

    Returns:
        list[str]: The chunks, in document order.
    """
    # Defaults match the values this module has always used, so existing callers
    # that pass only `text` get identical chunking.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_text(text)
def get_knowledge_base_files(collection_name: str = KB_COLLECTION_NAME):
    """
    List the distinct file names stored in a Qdrant collection.

    Pages through the entire collection with `scroll` (1000 points per request),
    so collections with more than one page of points are fully covered.

    Args:
        collection_name (str): The name of the collection to get the files from.

    Returns:
        dict: {"status": "success", "files": [...]} with the unique "file_name"
        payload values, or {"status": "error", "message": ...} if any call raised.
    """
    try:
        file_names = set()
        offset = None
        while True:
            # scroll returns (points, next_page_offset); the offset is None once
            # the last page has been read. A single call (as before) silently
            # dropped everything past the first 1000 points.
            points, offset = qdrant_client.scroll(
                collection_name=collection_name,
                limit=1000,
                offset=offset,
                with_payload=True,
            )
            for point in points:
                if point.payload and "file_name" in point.payload:
                    file_names.add(point.payload["file_name"])
            if offset is None:
                break
        return {"status": "success", "files": list(file_names)}
    except Exception as e:
        return {"status": "error", "message": f"Error getting files: {str(e)}"}
def extract_text_from_pdf(file_path: str) -> str:
    """
    Extract text from a PDF file using the module's MarkItDown converter.

    Args:
        file_path (str): Path to the PDF file.

    Returns:
        str: The converted text (`text_content` of the conversion result).

    Raises:
        Exception: If the conversion fails. The underlying error is chained as
            `__cause__` so its traceback is preserved.
    """
    try:
        result = markitdown_client.convert(file_path)
        return result.text_content
    except Exception as e:
        # `from e` keeps the original traceback; previously it was only
        # stringified into the message and the cause was lost.
        raise Exception(f"Error extracting text from PDF: {str(e)}") from e
def _index_text_chunks(text: str, file_name: str) -> int:
    """
    Chunk `text`, embed every chunk and insert each one into the knowledge base.

    Each chunk becomes one point in KB_COLLECTION_NAME whose payload holds
    "file_name", "chunk", "chunk_index" and "total_chunks".

    Args:
        text (str): The full document text.
        file_name (str): Name recorded in every chunk's payload.

    Returns:
        int: The number of chunks inserted.

    Raises:
        RuntimeError: If insert_qdrant_points reports an error status.
    """
    chunks = chunk_text(text)
    total_chunks = len(chunks)
    for i, chunk in enumerate(chunks):
        # Random unsigned 64-bit integer ID: the first 16 hex digits of a UUID4.
        point_id = int(uuid.uuid4().hex[:16], 16)
        embedding = embed_text(chunk)
        payload = {
            "file_name": file_name,
            "chunk": chunk,
            "chunk_index": i,
            "total_chunks": total_chunks,
        }
        result = insert_qdrant_points(
            collection_name=KB_COLLECTION_NAME,
            point_id=point_id,
            vector=embedding,
            payload=payload,
        )
        # NOTE(review): assumes insert_qdrant_points reports failures as a
        # {"status": "error", "message": ...} dict; any other return value is
        # treated as success. Confirm against app.services.qdrant_service.
        if isinstance(result, dict) and result.get("status") == "error":
            raise RuntimeError(
                f"Failed to insert chunk {i} of {file_name}: {result.get('message')}"
            )
    return total_chunks


def upload_pdf_file(file_path: str, file_name: Optional[str] = None):
    """
    Upload a PDF file to the knowledge base collection in Qdrant.

    The file is converted to text, split into chunks, and each chunk is
    embedded and inserted as its own point.

    Args:
        file_path (str): Path to the PDF file to upload.
        file_name (str): Name to use for the file in the collection. If None,
            uses the original filename.

    Returns:
        dict: {"status": "success", "message": ...} with the chunk count, or
        {"status": "error", "message": ...} if the collection could not be
        ensured, no text was extracted, or any step raised.
    """
    try:
        # Stop early: indexing into a collection that could not be created would
        # only fail later with a less clear error.
        collection_status = create_knowledge_base_collection_if_not_exists()
        if collection_status.get("status") == "error":
            return collection_status
        if file_name is None:
            file_name = os.path.basename(file_path)
        text = extract_text_from_pdf(file_path)
        # A PDF with no extractable text (e.g. scanned images) used to be
        # reported as "uploaded successfully with 0 chunks".
        if not text or not text.strip():
            return {
                "status": "error",
                "message": f"No text could be extracted from PDF file {file_name}.",
            }
        chunk_count = _index_text_chunks(text, file_name)
        return {
            "status": "success",
            "message": f"PDF file {file_name} uploaded successfully with {chunk_count} chunks.",
        }
    except Exception as e:
        return {"status": "error", "message": f"Error uploading PDF file: {str(e)}"}


def upload_text_file(file_path: str, file_name: Optional[str] = None):
    """
    Upload a UTF-8 text file to the knowledge base collection in Qdrant.

    The file is read, split into chunks, and each chunk is embedded and
    inserted as its own point.

    Args:
        file_path (str): Path to the text file to upload.
        file_name (str): Name to use for the file in the collection. If None,
            uses the original filename.

    Returns:
        dict: {"status": "success", "message": ...} with the chunk count, or
        {"status": "error", "message": ...} if the collection could not be
        ensured, the file is empty, or any step raised.
    """
    try:
        collection_status = create_knowledge_base_collection_if_not_exists()
        if collection_status.get("status") == "error":
            return collection_status
        if file_name is None:
            file_name = os.path.basename(file_path)
        with open(file_path, "r", encoding="utf-8") as file:
            text = file.read()
        if not text.strip():
            return {
                "status": "error",
                "message": f"Text file {file_name} is empty; nothing was uploaded.",
            }
        chunk_count = _index_text_chunks(text, file_name)
        return {
            "status": "success",
            "message": f"Text file {file_name} uploaded successfully with {chunk_count} chunks.",
        }
    except Exception as e:
        return {"status": "error", "message": f"Error uploading text file: {str(e)}"}
def delete_knowledge_base_file(
    file_name: str, collection_name: str = KB_COLLECTION_NAME
):
    """
    Delete a file's documents from a Qdrant collection.

    Delegates to delete_qdrant_documents and passes its result through
    unchanged. If that call raises, the exception is turned into an error dict.

    Args:
        file_name (str): The name of the file to delete.
        collection_name (str): The name of the collection to delete from.
    """
    try:
        return delete_qdrant_documents(collection_name, file_name)
    except Exception as e:
        return {"status": "error", "message": f"Error deleting file: {str(e)}"}
@observe(name="GET_SIMILAR_CHUNKS")
def get_similar_chunks(
    collection_name: str = KB_COLLECTION_NAME, query: str = "", limit: int = 10
):
    """
    Return the chunks in a Qdrant collection that are most similar to a query.

    Args:
        collection_name (str): The name of the collection to search.
        query (str): The text to search for. An empty or whitespace-only query
            is rejected without calling the embedding API.
        limit (int): The maximum number of chunks to return.

    Returns:
        dict: {"status": "success", "results": [...]} where each result has
        "score", "file_name" and "chunk", or {"status": "error", "message": ...}.
    """
    try:
        # The default query is "", which the embeddings endpoint would reject
        # with an opaque API error; fail fast with a clear message instead.
        if not query or not query.strip():
            return {
                "status": "error",
                "message": "Error getting similar chunks: query must not be empty.",
            }
        query_embedding = embed_text(query)
        result = search_qdrant_points(
            collection_name=collection_name, vector=query_embedding, limit=limit
        )
        # The hits live under "message". If the search reported an error, that
        # field is presumably text, and iterating it would fail with a misleading
        # attribute error.
        if isinstance(result, dict) and result.get("status") == "error":
            return {
                "status": "error",
                "message": f"Error getting similar chunks: {result.get('message')}",
            }
        formatted_results = []
        for point in result["message"]:
            # A point stored without a payload has payload None.
            payload = point.payload or {}
            formatted_results.append(
                {
                    "score": point.score,
                    "file_name": payload.get("file_name", ""),
                    "chunk": payload.get("chunk", ""),
                }
            )
        return {"status": "success", "results": formatted_results}
    except Exception as e:
        return {"status": "error", "message": f"Error getting similar chunks: {str(e)}"}
def get_file_chunks(file_name: str, collection_name: str = KB_COLLECTION_NAME):
    """
    Get all chunks stored for a specific file, ordered by chunk index.

    Pages through the whole collection with `scroll` (1000 points per request)
    and keeps the points whose "file_name" payload equals `file_name`.

    Args:
        file_name (str): The name of the file to get chunks for.
        collection_name (str): The name of the collection to search in.

    Returns:
        dict: {"status": "success", "chunks": [...]} where each entry has
        "chunk_index", "chunk" and "total_chunks", or
        {"status": "error", "message": ...} if any call raised.
    """
    try:
        file_chunks = []
        offset = None
        while True:
            # scroll returns (points, next_page_offset); None means no more pages.
            # A single call (as before) only saw the first 1000 points, so chunks
            # of files stored later were silently missing.
            # NOTE(review): a server-side scroll_filter on "file_name" would avoid
            # transferring unrelated points.
            points, offset = qdrant_client.scroll(
                collection_name=collection_name,
                limit=1000,
                offset=offset,
                with_payload=True,
            )
            for point in points:
                payload = point.payload
                if payload and payload.get("file_name") == file_name:
                    file_chunks.append(
                        {
                            "chunk_index": payload.get("chunk_index", 0),
                            "chunk": payload.get("chunk", ""),
                            "total_chunks": payload.get("total_chunks", 0),
                        }
                    )
            if offset is None:
                break
        file_chunks.sort(key=lambda x: x["chunk_index"])
        return {"status": "success", "chunks": file_chunks}
    except Exception as e:
        return {"status": "error", "message": f"Error getting file chunks: {str(e)}"}