|
1 | | -import os |
| 1 | +import asyncio |
| 2 | +import logging |
| 3 | +from contextlib import asynccontextmanager |
2 | 4 |
|
3 | 5 | import uvicorn |
4 | 6 | from fastapi import FastAPI, Request, Response |
| 7 | +from fastapi.middleware.cors import CORSMiddleware |
5 | 8 | from fastapi.responses import JSONResponse |
| 9 | +from fastapi_utils.tasks import repeat_every |
6 | 10 | from prometheus_client import CONTENT_TYPE_LATEST, generate_latest |
7 | | -from fastapi.middleware.cors import CORSMiddleware |
8 | | -import logging |
9 | 11 |
|
10 | 12 | # Endpoint routers |
11 | 13 | from src.endpoints.consumer.consumer_endpoint import router as consumer_router |
12 | | -from src.endpoints.metrics.fairness.group.dir import router as dir_router |
| 14 | +from src.endpoints.data.data_download import router as data_download_router |
13 | 15 | from src.endpoints.data.data_upload import router as data_upload_router |
14 | 16 |
|
| 17 | +# from src.endpoints.explainers import router as explainers_router |
| 18 | +from src.endpoints.explainers.global_explainer import router as explainers_global_router |
| 19 | +from src.endpoints.explainers.local_explainer import router as explainers_local_router |
| 20 | +from src.endpoints.metadata import router as metadata_router |
| 21 | + |
15 | 22 | # from src.endpoints.drift_metrics import router as drift_metrics_router |
16 | 23 | from src.endpoints.metrics.drift.approx_ks_test import ( |
17 | 24 | router as drift_approx_ks_test_router, |
18 | 25 | ) |
19 | 26 | from src.endpoints.metrics.drift.fourier_mmd import router as drift_fourier_mmd_router |
20 | 27 | from src.endpoints.metrics.drift.ks_test import router as drift_ks_test_router |
21 | 28 | from src.endpoints.metrics.drift.meanshift import router as drift_meanshift_router |
22 | | - |
23 | | -# from src.endpoints.explainers import router as explainers_router |
24 | | -from src.endpoints.explainers.global_explainer import router as explainers_global_router |
25 | | -from src.endpoints.explainers.local_explainer import router as explainers_local_router |
| 29 | +from src.endpoints.metrics.fairness.group.dir import router as dir_router |
26 | 30 | from src.endpoints.metrics.fairness.group.spd import router as spd_router |
27 | 31 | from src.endpoints.metrics.identity.identity_endpoint import router as identity_router |
28 | | -from src.endpoints.metadata import router as metadata_router |
29 | 32 | from src.endpoints.metrics.metrics_info import router as metrics_info_router |
30 | | -from src.endpoints.data.data_download import router as data_download_router |
| 33 | + |
| 34 | +from src.service.prometheus.prometheus_scheduler import PrometheusScheduler |
31 | 35 |
|
32 | 36 | try: |
33 | 37 | from src.endpoints.evaluation.lm_evaluation_harness import ( |
|
44 | 48 | ) |
45 | 49 | logger = logging.getLogger(__name__) |
46 | 50 |
|
| 51 | +prometheus_scheduler = PrometheusScheduler() |
| 52 | + |
| 53 | + |
| 54 | +@repeat_every( |
| 55 | + seconds=prometheus_scheduler.service_config.get("metrics_schedule", 30), |
| 56 | + logger=logger, |
| 57 | + raise_exceptions=False, |
| 58 | +) |
| 59 | +async def schedule_metrics_calculation(): |
| 60 | + await prometheus_scheduler.calculate() |
| 61 | + |
| 62 | + |
| 63 | +@asynccontextmanager |
| 64 | +async def lifespan(app: FastAPI): |
| 65 | + task = asyncio.create_task(schedule_metrics_calculation()) |
| 66 | + |
| 67 | + yield |
| 68 | + |
| 69 | + task.cancel() |
| 70 | + try: |
| 71 | + await task |
| 72 | + except asyncio.CancelledError: |
| 73 | + logger.info("Prometheus metrics calculation task cancelled during shutdown") |
| 74 | + |
| 75 | + |
47 | 76 | app = FastAPI( |
48 | 77 | title="TrustyAI Service API", |
49 | 78 | version="1.0.0rc0", |
50 | 79 | description="TrustyAI Service API", |
| 80 | + lifespan=lifespan, |
51 | 81 | ) |
52 | 82 |
|
53 | 83 | # CORS |
|
0 commit comments