Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepsearcher/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def main():
nargs="+", # 1 or more files or urls
help="Load knowledge from local files or from URLs.",
)
load_parser.add_argument(
"--batch_size",
type=int,
default=256,
help="Batch size for loading knowledge.",
)
load_parser.add_argument(
"--collection_name",
type=str,
Expand Down Expand Up @@ -88,6 +94,8 @@ def main():
kwargs["collection_description"] = args.collection_desc
if args.force_new_collection:
kwargs["force_new_collection"] = args.force_new_collection
if args.batch_size:
kwargs["batch_size"] = args.batch_size
if len(urls) > 0:
load_from_website(urls, **kwargs)
if len(local_files) > 0:
Expand Down
2 changes: 1 addition & 1 deletion deepsearcher/embedding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def embed_query(self, text: str) -> List[float]:
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(text) for text in texts]

def embed_chunks(self, chunks: List[Chunk], batch_size=256) -> List[Chunk]:
def embed_chunks(self, chunks: List[Chunk], batch_size: int = 256) -> List[Chunk]:
texts = [chunk.text for chunk in chunks]
batch_texts = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
embeddings = []
Expand Down
10 changes: 6 additions & 4 deletions deepsearcher/offline_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ def load_from_local_files(
collection_name: str = None,
collection_description: str = None,
force_new_collection: bool = False,
chunk_size=1500,
chunk_overlap=100,
chunk_size: int = 1500,
chunk_overlap: int = 100,
batch_size: int = 256,
):
vector_db = configuration.vector_db
if collection_name is None:
Expand Down Expand Up @@ -46,7 +47,7 @@ def load_from_local_files(
chunk_overlap=chunk_overlap,
)

chunks = embedding_model.embed_chunks(chunks)
chunks = embedding_model.embed_chunks(chunks, batch_size=batch_size)
vector_db.insert_data(collection=collection_name, chunks=chunks)


Expand All @@ -55,6 +56,7 @@ def load_from_website(
collection_name: str = None,
collection_description: str = None,
force_new_collection: bool = False,
batch_size: int = 256,
**crawl_kwargs,
):
if isinstance(urls, str):
Expand All @@ -73,5 +75,5 @@ def load_from_website(
all_docs = web_crawler.crawl_urls(urls, **crawl_kwargs)

chunks = split_docs_to_chunks(all_docs)
chunks = embedding_model.embed_chunks(chunks)
chunks = embedding_model.embed_chunks(chunks, batch_size=batch_size)
vector_db.insert_data(collection=collection_name, chunks=chunks)
12 changes: 12 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,18 @@ def load_files(
description="Optional description for the collection.",
examples=["This is a test collection."],
),
batch_size: int = Body(
None,
description="Optional batch size for the collection.",
examples=[256],
),
):
try:
load_from_local_files(
paths_or_directory=paths,
collection_name=collection_name,
collection_description=collection_description,
batch_size=batch_size,
)
return {"message": "Files loaded successfully."}
except Exception as e:
Expand All @@ -83,12 +89,18 @@ def load_website(
description="Optional description for the collection.",
examples=["This is a test collection."],
),
batch_size: int = Body(
None,
description="Optional batch size for the collection.",
examples=[256],
),
):
try:
load_from_website(
urls=urls,
collection_name=collection_name,
collection_description=collection_description,
batch_size=batch_size,
)
return {"message": "Website loaded successfully."}
except Exception as e:
Expand Down