|
1 | | -import sys |
2 | 1 | import os |
3 | | -from unittest.mock import patch, MagicMock |
| 2 | +import sys |
| 3 | +from unittest.mock import MagicMock, patch |
| 4 | + |
4 | 5 |
|
5 | 6 | # Add src to sys.path for imports |
6 | | -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) |
| 7 | +sys.path.insert(0, os.path.abspath(os.path.join( |
| 8 | + os.path.dirname(__file__), '..', 'src'))) |
| 9 | + |
7 | 10 |
|
8 | 11 | # Mock mlflow.sklearn.load_model before importing your app |
9 | 12 | mock_model = MagicMock() |
|
12 | 15 | patcher = patch('mlflow.sklearn.load_model', return_value=mock_model) |
13 | 16 | patcher.start() |
14 | 17 |
|
15 | | -from api.main import app # import AFTER patching |
16 | | - |
17 | 18 | from fastapi.testclient import TestClient |
| 19 | +from api.main import app # import AFTER patching |
18 | 20 |
|
19 | 21 | client = TestClient(app) |
20 | 22 |
|
| 23 | + |
21 | 24 | def test_predict(): |
22 | 25 | sample_data = { |
23 | 26 | "Recency": 1, |
@@ -45,12 +48,13 @@ def test_predict(): |
45 | 48 | "ProductCategory_ticket": False, |
46 | 49 | "ProductCategory_transport": False, |
47 | 50 | "ProductCategory_tv": False, |
48 | | - "ProductCategory_utility_bill": False |
| 51 | + "ProductCategory_utility_bill": False, |
49 | 52 | } |
50 | 53 |
|
51 | 54 | response = client.post("/predict", json=sample_data) |
52 | 55 | assert response.status_code == 200 |
53 | 56 | assert "risk_probability" in response.json() |
54 | 57 | assert abs(response.json()["risk_probability"] - 0.7) < 1e-6 |
55 | 58 |
|
| 59 | + |
56 | 60 | patcher.stop() |
0 commit comments