Skip to content

Commit 83e3b2b

Browse files
Merge pull request #184 from CarterPerez-dev/chore/ai-threat-detection-finish
chore complete + add pre commit changes
2 parents b6f75db + dc6567a commit 83e3b2b

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+6000
-275
lines changed

PROJECTS/advanced/ai-threat-detection/backend/app/api/health.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
Health and readiness probe endpoints for container
66
orchestration
77
8-
GET /health returns liveness status with uptime_seconds
9-
and pipeline_running flag. GET /ready checks database
8+
GET /health returns liveness status with uptime_seconds,
9+
pipeline_running flag, and per-stage pipeline_stats
10+
counters (parsed/enriched/scored/dispatched with error
11+
counts). GET /ready checks database
1012
connectivity (SELECT 1) and Redis ping, reports
1113
models_loaded status, and returns 503 if any dependency
1214
is down. Both endpoints read from app.state set during
@@ -34,10 +36,12 @@ async def health(request: Request) -> dict[str, object]:
3436
Liveness probe — returns 200 if the process is alive.
3537
"""
3638
uptime = time.monotonic() - request.app.state.startup_time
39+
pipeline = getattr(request.app.state, "pipeline", None)
3740
return {
3841
"status": "healthy",
3942
"uptime_seconds": round(uptime, 2),
4043
"pipeline_running": request.app.state.pipeline_running,
44+
"pipeline_stats": pipeline.stats if pipeline else {},
4145
}
4246

4347

PROJECTS/advanced/ai-threat-detection/backend/app/api/models_api.py

Lines changed: 141 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
77
GET /models/status returns models_loaded flag, detection
88
_mode (hybrid or rules), and active model metadata from
9-
the database. POST /models/retrain dispatches a
9+
the database. POST /models/retrain acquires _retrain_lock
10+
(returning 409 if already running), dispatches a
1011
background retraining job that loads stored ThreatEvents,
1112
labels them using review_label or score thresholds
1213
(SCORE_ATTACK_THRESHOLD 0.5, SCORE_NORMAL_CEILING 0.3),
1314
supplements with synthetic data if below MIN_TRAINING_
1415
SAMPLES (200), runs TrainingOrchestrator, and writes
1516
model metadata. _fallback_synthetic spawns a subprocess
16-
CLI train command when no real events exist
17+
CLI train command with lifecycle tracking via
18+
_synthetic_process
1719
1820
Connects to:
1921
config.py - settings.model_dir, ensemble
@@ -25,10 +27,13 @@
2527
cli/main - _write_metadata
2628
"""
2729

30+
import asyncio
2831
import logging
32+
import subprocess
2933
import uuid
3034

31-
from fastapi import APIRouter, BackgroundTasks, Request
35+
from fastapi import APIRouter, BackgroundTasks, Request, Response
36+
from fastapi.responses import JSONResponse
3237
from sqlalchemy import func, select
3338
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
3439

@@ -46,6 +51,9 @@
4651
SYNTHETIC_SUPPLEMENT_NORMAL = 500
4752
SYNTHETIC_SUPPLEMENT_ATTACK = 250
4853

54+
_retrain_lock = asyncio.Lock()
55+
_synthetic_process: subprocess.Popen[bytes] | None = None
56+
4957

5058
@router.get("/status")
5159
async def model_status(request: Request) -> dict[str, object]:
@@ -68,15 +76,21 @@ async def model_status(request: Request) -> dict[str, object]:
6876
}
6977

7078

71-
@router.post("/retrain", status_code=202)
79+
@router.post("/retrain", status_code=202, response_model=None)
7280
async def retrain(
7381
request: Request,
7482
background_tasks: BackgroundTasks,
75-
) -> dict[str, object]:
83+
) -> dict[str, object] | Response:
7684
"""
7785
Dispatch a model retraining job using real stored
7886
threat events supplemented with synthetic data
7987
"""
88+
if _retrain_lock.locked():
89+
return JSONResponse(
90+
status_code=409,
91+
content={"status": "conflict", "job_id": ""},
92+
)
93+
8094
session_factory = getattr(request.app.state, "session_factory", None)
8195
if session_factory is None:
8296
return {"status": "error", "job_id": ""}
@@ -99,129 +113,155 @@ async def _retrain_from_db(
99113
supplement with synthetic data if needed, and run
100114
the full training pipeline
101115
"""
102-
import asyncio
103116
import dataclasses
104117
from pathlib import Path
105118

106119
import numpy as np
107120

108121
from ml.orchestrator import TrainingOrchestrator
109122

110-
logger.info("Retrain job %s: loading stored events", job_id)
111-
112-
async with session_factory() as session:
113-
count = (await session.execute(
114-
select(func.count()).select_from(ThreatEvent)
115-
)).scalar_one()
116-
117-
if count == 0:
118-
logger.warning(
119-
"Retrain job %s: no stored events, using synthetic only",
120-
job_id,
121-
)
122-
_fallback_synthetic(job_id)
123-
return
123+
async with _retrain_lock:
124+
logger.info("Retrain job %s: loading stored events", job_id)
124125

125-
rows = (await session.execute(
126-
select(ThreatEvent)
127-
)).scalars().all()
126+
async with session_factory() as session:
127+
count = (await session.execute(
128+
select(func.count()).select_from(ThreatEvent)
129+
)).scalar_one()
130+
131+
if count == 0:
132+
logger.warning(
133+
"Retrain job %s: no stored events, "
134+
"using synthetic only",
135+
job_id,
136+
)
137+
_fallback_synthetic(job_id)
138+
return
139+
140+
rows = (await session.execute(
141+
select(ThreatEvent)
142+
)).scalars().all()
143+
144+
vectors: list[list[float]] = []
145+
labels: list[int] = []
146+
147+
for event in rows:
148+
if not event.feature_vector:
149+
continue
150+
151+
if event.reviewed and event.review_label:
152+
label = (
153+
1 if event.review_label == "true_positive"
154+
else 0
155+
)
156+
elif event.threat_score >= SCORE_ATTACK_THRESHOLD:
157+
label = 1
158+
elif event.threat_score < SCORE_NORMAL_CEILING:
159+
label = 0
160+
else:
161+
continue
162+
163+
vectors.append(event.feature_vector)
164+
labels.append(label)
128165

129-
vectors: list[list[float]] = []
130-
labels: list[int] = []
166+
logger.info(
167+
"Retrain job %s: %d usable events from DB "
168+
"(normal=%d, attack=%d)",
169+
job_id,
170+
len(vectors),
171+
labels.count(0),
172+
labels.count(1),
173+
)
131174

132-
for event in rows:
133-
if not event.feature_vector:
134-
continue
175+
from ml.synthetic import generate_mixed_dataset
135176

136-
if event.reviewed and event.review_label:
137-
label = 1 if event.review_label == "true_positive" else 0
138-
elif event.threat_score >= SCORE_ATTACK_THRESHOLD:
139-
label = 1
140-
elif event.threat_score < SCORE_NORMAL_CEILING:
141-
label = 0
177+
if len(vectors) < MIN_TRAINING_SAMPLES:
178+
syn_X, syn_y = generate_mixed_dataset(
179+
SYNTHETIC_SUPPLEMENT_NORMAL,
180+
SYNTHETIC_SUPPLEMENT_ATTACK,
181+
)
182+
X = np.concatenate([
183+
np.array(vectors, dtype=np.float32),
184+
syn_X,
185+
]) if vectors else syn_X
186+
y = np.concatenate([
187+
np.array(labels, dtype=np.int32),
188+
syn_y,
189+
]) if labels else syn_y
190+
logger.info(
191+
"Retrain job %s: supplemented with "
192+
"%d synthetic samples",
193+
job_id,
194+
len(syn_X),
195+
)
142196
else:
143-
continue
144-
145-
vectors.append(event.feature_vector)
146-
labels.append(label)
147-
148-
logger.info(
149-
"Retrain job %s: %d usable events from DB "
150-
"(normal=%d, attack=%d)",
151-
job_id,
152-
len(vectors),
153-
labels.count(0),
154-
labels.count(1),
155-
)
156-
157-
from ml.synthetic import generate_mixed_dataset
158-
159-
if len(vectors) < MIN_TRAINING_SAMPLES:
160-
syn_X, syn_y = generate_mixed_dataset(
161-
SYNTHETIC_SUPPLEMENT_NORMAL,
162-
SYNTHETIC_SUPPLEMENT_ATTACK,
197+
X = np.array(vectors, dtype=np.float32)
198+
y = np.array(labels, dtype=np.int32)
199+
200+
output_dir = Path(settings.model_dir)
201+
loop = asyncio.get_running_loop()
202+
result = await loop.run_in_executor(
203+
None,
204+
lambda: TrainingOrchestrator(
205+
output_dir=output_dir,
206+
).run(X, y),
163207
)
164-
X = np.concatenate([
165-
np.array(vectors, dtype=np.float32),
166-
syn_X,
167-
]) if vectors else syn_X
168-
y = np.concatenate([
169-
np.array(labels, dtype=np.int32),
170-
syn_y,
171-
]) if labels else syn_y
208+
172209
logger.info(
173-
"Retrain job %s: supplemented with %d synthetic samples",
210+
"Retrain job %s complete: passed_gates=%s",
174211
job_id,
175-
len(syn_X),
212+
result.passed_gates,
176213
)
177-
else:
178-
X = np.array(vectors, dtype=np.float32)
179-
y = np.array(labels, dtype=np.int32)
180-
181-
output_dir = Path(settings.model_dir)
182-
loop = asyncio.get_running_loop()
183-
result = await loop.run_in_executor(
184-
None,
185-
lambda: TrainingOrchestrator(output_dir=output_dir).run(X, y),
186-
)
187-
188-
logger.info(
189-
"Retrain job %s complete: passed_gates=%s",
190-
job_id,
191-
result.passed_gates,
192-
)
193214

194-
try:
195-
from cli.main import _write_metadata
215+
try:
216+
from cli.main import _write_metadata
196217

197-
metrics: dict[str, object] = (
198-
dataclasses.asdict(result.ensemble_metrics)
199-
if result.ensemble_metrics else {}
200-
)
201-
await _write_metadata(
202-
output_dir,
203-
len(X),
204-
metrics,
205-
result.mlflow_run_id,
206-
result.ae_metrics.get("ae_threshold"),
207-
)
208-
except Exception:
209-
logger.exception(
210-
"Retrain job %s: failed to write metadata",
211-
job_id,
212-
)
218+
metrics: dict[str, object] = (
219+
dataclasses.asdict(result.ensemble_metrics)
220+
if result.ensemble_metrics else {}
221+
)
222+
await _write_metadata(
223+
output_dir,
224+
len(X),
225+
metrics,
226+
result.mlflow_run_id,
227+
result.ae_metrics.get("ae_threshold"),
228+
)
229+
except Exception:
230+
logger.exception(
231+
"Retrain job %s: failed to write metadata",
232+
job_id,
233+
)
213234

214235

215236
def _fallback_synthetic(job_id: str) -> None:
216237
"""
217238
Run training with synthetic data only when no real
218239
events exist
219240
"""
220-
import subprocess
241+
global _synthetic_process # noqa: PLW0603
221242
import sys
222243

223-
logger.info("Retrain job %s: falling back to synthetic training", job_id)
224-
subprocess.Popen(
244+
if _synthetic_process is not None:
245+
if _synthetic_process.poll() is None:
246+
logger.info(
247+
"Retrain job %s: synthetic training already "
248+
"running (pid=%d)",
249+
job_id,
250+
_synthetic_process.pid,
251+
)
252+
return
253+
rc = _synthetic_process.returncode
254+
if rc != 0:
255+
logger.warning(
256+
"Previous synthetic training exited with %d",
257+
rc,
258+
)
259+
260+
logger.info(
261+
"Retrain job %s: falling back to synthetic training",
262+
job_id,
263+
)
264+
_synthetic_process = subprocess.Popen(
225265
[
226266
sys.executable,
227267
"-m",

PROJECTS/advanced/ai-threat-detection/backend/app/config.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
settings (size 32, timeout 50ms), and ML configuration
1414
(model_dir, detection_mode, ensemble weights for
1515
autoencoder/random-forest/isolation-forest at 0.40/0.40
16-
/0.20, ae_threshold_percentile 99.5, MLflow tracking
17-
URI). Exports a module-level singleton settings instance
16+
/0.20 with model_validator enforcing sum-to-1.0,
17+
ae_threshold_percentile 99.5, MLflow tracking URI).
18+
Exports a module-level singleton settings instance
1819
1920
Connects to:
2021
factory.py - consumed in lifespan and create_app
@@ -24,6 +25,9 @@
2425
core/enrichment/ - geoip_db_path
2526
"""
2627

28+
from typing import Self
29+
30+
from pydantic import model_validator
2731
from pydantic_settings import BaseSettings, SettingsConfigDict
2832

2933

@@ -70,5 +74,21 @@ class Settings(BaseSettings):
7074
ae_threshold_percentile: float = 99.5
7175
mlflow_tracking_uri: str = "file:./mlruns"
7276

77+
@model_validator(mode="after")
78+
def _check_ensemble_weights(self) -> Self:
79+
"""
80+
Validate that ensemble weights sum to 1.0
81+
"""
82+
total = (
83+
self.ensemble_weight_ae
84+
+ self.ensemble_weight_rf
85+
+ self.ensemble_weight_if
86+
)
87+
if abs(total - 1.0) > 1e-6:
88+
raise ValueError(
89+
f"Ensemble weights must sum to 1.0, got {total:.6f}"
90+
)
91+
return self
92+
7393

7494
settings = Settings()

0 commit comments

Comments
 (0)