Skip to content

Commit 4ecfca9

Browse files
authored
Merge pull request #51 from whyhow-ai/40-feature-seperate-embeddings-and-llm-inference
WIP: Updating references and separating concerns.
2 parents 50d927b + f59dd83 commit 4ecfca9

30 files changed

+454
-517
lines changed

backend/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
- Added support for queries without source data in vector database
1313
- Graceful failure of triple export when no chunks are found
1414

15+
### Changed
16+
17+
- Separated embedding service from LLM service
18+
1519
## [v0.1.5] - 2024-10-29
1620

1721
### Changed

backend/docker-compose.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ services:
1414
environment:
1515
- ENVIRONMENT=dev
1616
- TESTING=0
17+
env_file:
18+
- .env
1719
# Add healthcheck for better orchestration
1820
healthcheck:
1921
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]

backend/pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ dependencies = [
6565
"pydantic-settings>=2.5.2",
6666
"pydantic_core>=2.23.3",
6767
"pymilvus>=2.4.6",
68-
"PyMuPDF>=1.24.10",
69-
"PyMuPDFb>=1.24.10",
7068
"pypdf>=5.0.0",
7169
"PyPDF2>=3.0.1",
7270
"python-dateutil>=2.9.0",

backend/src/app/api/v1/endpoints/graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
ExportTriplesResponseSchema,
1919
)
2020
from app.services.graph_service import generate_triples
21-
from app.services.llm.base import LLMService
21+
from app.services.llm.base import CompletionService
2222
from app.services.llm_service import generate_schema
2323

2424
router = APIRouter(tags=["Graph"])
@@ -28,7 +28,7 @@
2828
@router.post("/export-triples", response_model=ExportTriplesResponseSchema)
2929
async def export_triples(
3030
request: ExportTriplesRequestSchema,
31-
llm_service: LLMService = Depends(get_llm_service),
31+
llm_service: CompletionService = Depends(get_llm_service),
3232
) -> ExportTriplesResponseSchema:
3333
"""
3434
Generate and export triples from a table.
@@ -40,7 +40,7 @@ async def export_triples(
4040
----------
4141
request : ExportTriplesRequestSchema
4242
The request body containing the table data (columns, rows, and cells).
43-
llm_service : LLMService
43+
llm_service : CompletionService
4444
The language model service used for generating the schema, injected by FastAPI.
4545
4646
Returns

backend/src/app/api/v1/endpoints/query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
QueryRequestSchema,
1313
QueryResult,
1414
)
15-
from app.services.llm.base import LLMService
15+
from app.services.llm.base import CompletionService
1616
from app.services.query_service import (
1717
decomposition_query,
1818
hybrid_query,
@@ -31,7 +31,7 @@
3131
@router.post("", response_model=QueryAnswerResponse)
3232
async def run_query(
3333
request: QueryRequestSchema,
34-
llm_service: LLMService = Depends(get_llm_service),
34+
llm_service: CompletionService = Depends(get_llm_service),
3535
vector_db_service: VectorDBService = Depends(get_vector_db_service),
3636
) -> QueryAnswerResponse:
3737
"""
@@ -45,7 +45,7 @@ async def run_query(
4545
----------
4646
request : QueryRequestSchema
4747
The incoming query request.
48-
llm_service : LLMService
48+
llm_service : CompletionService
4949
The language model service.
5050
vector_db_service : VectorDBService
5151
The vector database service.

backend/src/app/core/dependencies.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,46 @@
44

55
from app.core.config import Settings, get_settings
66
from app.services.document_service import DocumentService
7-
from app.services.llm.base import LLMService
8-
from app.services.llm.factory import LLMFactory
7+
from app.services.embedding.base import EmbeddingService
8+
from app.services.embedding.factory import EmbeddingServiceFactory
9+
from app.services.llm.base import CompletionService
10+
from app.services.llm.factory import CompletionServiceFactory
911
from app.services.vector_db.base import VectorDBService
1012
from app.services.vector_db.factory import VectorDBFactory
1113

1214

13-
def get_llm_service(settings: Settings = Depends(get_settings)) -> LLMService:
15+
def get_llm_service(
16+
settings: Settings = Depends(get_settings),
17+
) -> CompletionService:
1418
"""Get the LLM service for the application."""
15-
llm_service = LLMFactory.create_llm_service(settings)
19+
llm_service = CompletionServiceFactory.create_service(settings)
1620
if llm_service is None:
1721
raise ValueError(
1822
f"Failed to create LLM service for provider: {settings.llm_provider}"
1923
)
2024
return llm_service
2125

2226

27+
def get_embedding_service(
28+
settings: Settings = Depends(get_settings),
29+
) -> EmbeddingService:
30+
"""Get the embedding service for the application."""
31+
embedding_service = EmbeddingServiceFactory.create_service(settings)
32+
if embedding_service is None:
33+
raise ValueError(
34+
f"Failed to create embedding service for provider: {settings.embedding_provider}"
35+
)
36+
return embedding_service
37+
38+
2339
def get_vector_db_service(
2440
settings: Settings = Depends(get_settings),
25-
llm_service: LLMService = Depends(get_llm_service),
41+
embedding_service: EmbeddingService = Depends(get_embedding_service),
42+
llm_service: CompletionService = Depends(get_llm_service),
2643
) -> VectorDBService:
2744
"""Get the vector database service for the application."""
2845
vector_db_service = VectorDBFactory.create_vector_db_service(
29-
llm_service, settings
46+
embedding_service, llm_service, settings
3047
)
3148
if vector_db_service is None:
3249
raise ValueError(
@@ -38,7 +55,7 @@ def get_vector_db_service(
3855
def get_document_service(
3956
settings: Settings = Depends(get_settings),
4057
vector_db_service: VectorDBService = Depends(get_vector_db_service),
41-
llm_service: LLMService = Depends(get_llm_service),
58+
llm_service: CompletionService = Depends(get_llm_service),
4259
) -> DocumentService:
4360
"""Get the document service for the application."""
4461
return DocumentService(vector_db_service, llm_service, settings)

backend/src/app/services/document_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from langchain.text_splitter import RecursiveCharacterTextSplitter
1111

1212
from app.core.config import Settings
13-
from app.services.llm.base import LLMService
13+
from app.services.llm.base import CompletionService
1414
from app.services.loaders.factory import LoaderFactory
1515
from app.services.vector_db.base import VectorDBService
1616

@@ -23,7 +23,7 @@ class DocumentService:
2323
def __init__(
2424
self,
2525
vector_db_service: VectorDBService,
26-
llm_service: LLMService,
26+
llm_service: CompletionService,
2727
settings: Settings,
2828
):
2929
"""Document service."""
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Embedding service module."""
2+
3+
from app.services.embedding.base import EmbeddingService
4+
from app.services.embedding.factory import EmbeddingServiceFactory
5+
6+
__all__ = ["EmbeddingService", "EmbeddingServiceFactory"]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Abstract base class for embedding services."""
2+
3+
from abc import ABC, abstractmethod
4+
from typing import List
5+
6+
7+
class EmbeddingService(ABC):
8+
"""Abstract base class for embedding services."""
9+
10+
@abstractmethod
11+
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
12+
"""Get the embeddings for the given text."""
13+
pass
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Factory for creating embedding services."""
2+
3+
import logging
4+
from typing import Optional
5+
6+
from app.core.config import Settings
7+
from app.services.embedding.base import EmbeddingService
8+
from app.services.embedding.openai_embedding_service import (
9+
OpenAIEmbeddingService,
10+
)
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class EmbeddingServiceFactory:
16+
"""Factory for creating embedding services."""
17+
18+
@staticmethod
19+
def create_service(settings: Settings) -> Optional[EmbeddingService]:
20+
"""Create an embedding service."""
21+
logger.info(
22+
f"Creating embedding service for provider: {settings.embedding_provider}"
23+
)
24+
if settings.embedding_provider == "openai":
25+
return OpenAIEmbeddingService(settings)
26+
# Add more providers here when needed
27+
return None

0 commit comments

Comments
 (0)