Skip to content

Commit f59dd83

Browse files
committed
Updating tests to seperate concerns from providers
1 parent 3e0ab3b commit f59dd83

File tree

8 files changed

+239
-451
lines changed

8 files changed

+239
-451
lines changed

backend/tests/conftest.py

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
1-
import os
2-
from unittest.mock import AsyncMock, MagicMock
1+
from unittest.mock import AsyncMock
32

43
import pytest
54
from fastapi.testclient import TestClient
65

76
from app import main
87
from app.core.config import Settings, get_settings
8+
from app.models.query_core import Chunk
9+
from app.schemas.query_api import QueryResult
910
from app.services.document_service import DocumentService
10-
from app.services.embedding.openai_embedding_service import (
11-
OpenAIEmbeddingService,
12-
)
13-
from app.services.llm.factory import CompletionServiceFactory
14-
from app.services.llm.openai_llm_service import OpenAICompletionService
11+
from app.services.embedding.base import EmbeddingService
12+
from app.services.llm.base import CompletionService
1513
from app.services.vector_db.base import VectorDBService
16-
from app.services.vector_db.factory import VectorDBFactory
1714

1815

1916
def get_settings_override():
2017
return Settings(
2118
testing=True,
22-
database_url=os.environ.get("DATABASE_TEST_URL"),
19+
database_url="test_db_url",
2320
chunk_size=1000,
2421
chunk_overlap=200,
2522
loader="test_loader",
2623
vector_db_provider="test_vector_db",
27-
llm_provider="test_llm",
28-
openai_api_key="test_api_key",
24+
llm_provider="test_provider",
25+
embedding_provider="test_provider",
26+
openai_api_key=None,
2927
embedding_model="test_embedding_model",
3028
dimensions=1536,
3129
llm_model="test_llm_model",
@@ -37,16 +35,6 @@ def test_settings():
3735
return get_settings_override()
3836

3937

40-
@pytest.fixture(scope="session")
41-
def mock_openai_client():
42-
return MagicMock()
43-
44-
45-
@pytest.fixture(scope="session")
46-
def mock_openai_embeddings():
47-
return MagicMock()
48-
49-
5038
@pytest.fixture(scope="session")
5139
def test_app(test_settings):
5240
main.app.dependency_overrides[get_settings] = lambda: test_settings
@@ -61,45 +49,64 @@ def client(test_app):
6149

6250

6351
@pytest.fixture(scope="session")
64-
def mock_vector_db_service(mock_embeddings_service):
65-
service = AsyncMock(spec=VectorDBService)
66-
service.embedding_service = mock_embeddings_service
52+
def mock_embeddings_service(test_settings):
53+
service = AsyncMock(spec=EmbeddingService)
54+
service.get_embeddings.return_value = [[0.1, 0.2, 0.3]]
55+
service.settings = test_settings
56+
service.model = test_settings.embedding_model
6757
return service
6858

6959

7060
@pytest.fixture(scope="session")
71-
def mock_embeddings_service():
72-
service = AsyncMock(spec=OpenAIEmbeddingService)
73-
service.get_embeddings.return_value = [0.1, 0.2, 0.3]
61+
def mock_llm_service(test_settings):
62+
service = AsyncMock(spec=CompletionService)
63+
service.settings = test_settings
64+
service.model = test_settings.llm_model
65+
66+
# Create a mock response that matches test expectations
67+
mock_response = QueryResult(
68+
answer="The capital of France is Paris.",
69+
chunks=[Chunk(content="Paris is the capital of France.", page=1)],
70+
)
71+
72+
# Update the mock methods to return the same response
73+
service.generate_completion = AsyncMock(return_value=mock_response)
74+
service.generate_response = AsyncMock(return_value=mock_response)
7475
return service
7576

7677

7778
@pytest.fixture(scope="session")
78-
def mock_llm_service():
79-
service = AsyncMock(spec=OpenAICompletionService)
80-
service.client = MagicMock()
81-
service.generate_completion.return_value = "Mocked completion"
79+
def mock_vector_db_service(
80+
mock_embeddings_service, mock_llm_service, test_settings
81+
):
82+
service = AsyncMock(spec=VectorDBService)
83+
service.embedding_service = mock_embeddings_service
84+
service.llm_service = mock_llm_service
85+
service.settings = test_settings
8286
return service
8387

8488

85-
@pytest.fixture
86-
def mock_factories(mock_llm_service, mock_vector_db_service):
89+
@pytest.fixture(scope="session")
90+
def document_service(test_settings, mock_vector_db_service, mock_llm_service):
91+
return DocumentService(
92+
mock_vector_db_service, mock_llm_service, test_settings
93+
)
94+
95+
96+
@pytest.fixture(scope="session", autouse=True)
97+
def mock_dependencies(
98+
mock_embeddings_service, mock_llm_service, mock_vector_db_service
99+
):
100+
"""Mock dependencies for API endpoints only"""
87101
with pytest.MonkeyPatch.context() as m:
102+
# Only mock the embedding and LLM factories
88103
m.setattr(
89-
CompletionServiceFactory,
90-
"create_service",
91-
lambda *args, **kwargs: mock_llm_service,
104+
"app.services.embedding.factory.EmbeddingServiceFactory.create_service",
105+
lambda *args, **kwargs: mock_embeddings_service,
92106
)
93107
m.setattr(
94-
VectorDBFactory,
95-
"create_vector_db_service",
96-
lambda *args, **kwargs: mock_vector_db_service,
108+
"app.services.llm.factory.CompletionServiceFactory.create_service",
109+
lambda *args, **kwargs: mock_llm_service,
97110
)
111+
# Don't mock the vector DB factory as we want to test it
98112
yield
99-
100-
101-
@pytest.fixture(scope="session")
102-
def document_service(test_settings, mock_vector_db_service, mock_llm_service):
103-
return DocumentService(
104-
mock_vector_db_service, mock_llm_service, test_settings
105-
)

backend/tests/test_endpoint_graph.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,16 @@
1+
# test_endpoint_graph.py
12
import json
23
from unittest.mock import AsyncMock, MagicMock, patch
34

45
import pytest
56
from fastapi import status
6-
from fastapi.testclient import TestClient
77

8-
from app.core.config import Settings
98
from app.models.document import Document
109
from app.models.graph import GraphChunk, Node, Relation, Triple
1110
from app.models.table import Column, Row, TablePrompt
1211
from app.schemas.graph_api import ExportTriplesResponseSchema
1312

1413

15-
@pytest.fixture
16-
def mock_openai_client():
17-
return MagicMock()
18-
19-
20-
@pytest.fixture
21-
def mock_openai_embeddings():
22-
return MagicMock()
23-
24-
2514
@pytest.fixture
2615
def mock_generate_schema():
2716
return AsyncMock()
@@ -32,16 +21,6 @@ def mock_generate_triples():
3221
return AsyncMock()
3322

3423

35-
@pytest.fixture
36-
def client():
37-
with patch("app.core.dependencies.get_settings") as mock_get_settings:
38-
mock_get_settings.return_value = Settings(openai_api_key=None)
39-
from app.main import app
40-
41-
with TestClient(app) as test_client:
42-
yield test_client
43-
44-
4524
def create_test_prompt():
4625
return TablePrompt(
4726
entityType="test",
@@ -166,6 +145,7 @@ def test_export_triples_success(
166145
],
167146
)
168147

148+
# Remove the patch for openai.OpenAI since it's mocked globally
169149
with (
170150
patch(
171151
"app.api.v1.endpoints.graph.get_llm_service",
@@ -178,7 +158,6 @@ def test_export_triples_success(
178158
"app.api.v1.endpoints.graph.generate_triples",
179159
mock_generate_triples,
180160
),
181-
patch("openai.OpenAI", return_value=mock_llm_service.client),
182161
):
183162
response = client.post(
184163
"/api/v1/graph/export-triples", json=request_data
@@ -300,7 +279,6 @@ def test_export_triples_unexpected_error(
300279
patch(
301280
"app.api.v1.endpoints.graph.generate_schema", mock_generate_schema
302281
),
303-
patch("openai.OpenAI", return_value=mock_llm_service.client),
304282
):
305283
response = client.post(
306284
"/api/v1/graph/export-triples", json=request_data

0 commit comments

Comments
 (0)