66
77GET /models/status returns models_loaded flag, detection
88_mode (hybrid or rules), and active model metadata from
9- the database. POST /models/retrain dispatches a
9+ the database. POST /models/retrain checks _retrain_lock
10+ (returning 409 if held; the job acquires it), dispatches a
1011background retraining job that loads stored ThreatEvents,
1112labels them using review_label or score thresholds
1213(SCORE_ATTACK_THRESHOLD 0.5, SCORE_NORMAL_CEILING 0.3),
1314supplements with synthetic data if below MIN_TRAINING_
1415SAMPLES (200), runs TrainingOrchestrator, and writes
1516model metadata. _fallback_synthetic spawns a subprocess
16- CLI train command when no real events exist
17+ CLI train command with lifecycle tracking via
18+ _synthetic_process
1719
1820Connects to:
1921 config.py - settings.model_dir, ensemble
2527 cli/main - _write_metadata
2628"""
2729
30+ import asyncio
2831import logging
32+ import subprocess
2933import uuid
3034
31- from fastapi import APIRouter , BackgroundTasks , Request
35+ from fastapi import APIRouter , BackgroundTasks , Request , Response
36+ from fastapi .responses import JSONResponse
3237from sqlalchemy import func , select
3338from sqlalchemy .ext .asyncio import AsyncSession , async_sessionmaker
3439
4651SYNTHETIC_SUPPLEMENT_NORMAL = 500
4752SYNTHETIC_SUPPLEMENT_ATTACK = 250
4853
54+ _retrain_lock = asyncio .Lock ()
55+ _synthetic_process : subprocess .Popen [bytes ] | None = None
56+
4957
5058@router .get ("/status" )
5159async def model_status (request : Request ) -> dict [str , object ]:
@@ -68,15 +76,21 @@ async def model_status(request: Request) -> dict[str, object]:
6876 }
6977
7078
71- @router .post ("/retrain" , status_code = 202 )
79+ @router .post ("/retrain" , status_code = 202 , response_model = None )
7280async def retrain (
7381 request : Request ,
7482 background_tasks : BackgroundTasks ,
75- ) -> dict [str , object ]:
83+ ) -> dict [str , object ] | Response :
7684 """
7785 Dispatch a model retraining job using real stored
7886 threat events supplemented with synthetic data
7987 """
88+ if _retrain_lock .locked ():
89+ return JSONResponse (
90+ status_code = 409 ,
91+ content = {"status" : "conflict" , "job_id" : "" },
92+ )
93+
8094 session_factory = getattr (request .app .state , "session_factory" , None )
8195 if session_factory is None :
8296 return {"status" : "error" , "job_id" : "" }
@@ -99,129 +113,155 @@ async def _retrain_from_db(
99113 supplement with synthetic data if needed, and run
100114 the full training pipeline
101115 """
102- import asyncio
103116 import dataclasses
104117 from pathlib import Path
105118
106119 import numpy as np
107120
108121 from ml .orchestrator import TrainingOrchestrator
109122
110- logger .info ("Retrain job %s: loading stored events" , job_id )
111-
112- async with session_factory () as session :
113- count = (await session .execute (
114- select (func .count ()).select_from (ThreatEvent )
115- )).scalar_one ()
116-
117- if count == 0 :
118- logger .warning (
119- "Retrain job %s: no stored events, using synthetic only" ,
120- job_id ,
121- )
122- _fallback_synthetic (job_id )
123- return
123+ async with _retrain_lock :
124+ logger .info ("Retrain job %s: loading stored events" , job_id )
124125
125- rows = (await session .execute (
126- select (ThreatEvent )
127- )).scalars ().all ()
126+ async with session_factory () as session :
127+ count = (await session .execute (
128+ select (func .count ()).select_from (ThreatEvent )
129+ )).scalar_one ()
130+
131+ if count == 0 :
132+ logger .warning (
133+ "Retrain job %s: no stored events, "
134+ "using synthetic only" ,
135+ job_id ,
136+ )
137+ _fallback_synthetic (job_id )
138+ return
139+
140+ rows = (await session .execute (
141+ select (ThreatEvent )
142+ )).scalars ().all ()
143+
144+ vectors : list [list [float ]] = []
145+ labels : list [int ] = []
146+
147+ for event in rows :
148+ if not event .feature_vector :
149+ continue
150+
151+ if event .reviewed and event .review_label :
152+ label = (
153+ 1 if event .review_label == "true_positive"
154+ else 0
155+ )
156+ elif event .threat_score >= SCORE_ATTACK_THRESHOLD :
157+ label = 1
158+ elif event .threat_score < SCORE_NORMAL_CEILING :
159+ label = 0
160+ else :
161+ continue
162+
163+ vectors .append (event .feature_vector )
164+ labels .append (label )
128165
129- vectors : list [list [float ]] = []
130- labels : list [int ] = []
166+ logger .info (
167+ "Retrain job %s: %d usable events from DB "
168+ "(normal=%d, attack=%d)" ,
169+ job_id ,
170+ len (vectors ),
171+ labels .count (0 ),
172+ labels .count (1 ),
173+ )
131174
132- for event in rows :
133- if not event .feature_vector :
134- continue
175+ from ml .synthetic import generate_mixed_dataset
135176
136- if event .reviewed and event .review_label :
137- label = 1 if event .review_label == "true_positive" else 0
138- elif event .threat_score >= SCORE_ATTACK_THRESHOLD :
139- label = 1
140- elif event .threat_score < SCORE_NORMAL_CEILING :
141- label = 0
177+ if len (vectors ) < MIN_TRAINING_SAMPLES :
178+ syn_X , syn_y = generate_mixed_dataset (
179+ SYNTHETIC_SUPPLEMENT_NORMAL ,
180+ SYNTHETIC_SUPPLEMENT_ATTACK ,
181+ )
182+ X = np .concatenate ([
183+ np .array (vectors , dtype = np .float32 ),
184+ syn_X ,
185+ ]) if vectors else syn_X
186+ y = np .concatenate ([
187+ np .array (labels , dtype = np .int32 ),
188+ syn_y ,
189+ ]) if labels else syn_y
190+ logger .info (
191+ "Retrain job %s: supplemented with "
192+ "%d synthetic samples" ,
193+ job_id ,
194+ len (syn_X ),
195+ )
142196 else :
143- continue
144-
145- vectors .append (event .feature_vector )
146- labels .append (label )
147-
148- logger .info (
149- "Retrain job %s: %d usable events from DB "
150- "(normal=%d, attack=%d)" ,
151- job_id ,
152- len (vectors ),
153- labels .count (0 ),
154- labels .count (1 ),
155- )
156-
157- from ml .synthetic import generate_mixed_dataset
158-
159- if len (vectors ) < MIN_TRAINING_SAMPLES :
160- syn_X , syn_y = generate_mixed_dataset (
161- SYNTHETIC_SUPPLEMENT_NORMAL ,
162- SYNTHETIC_SUPPLEMENT_ATTACK ,
197+ X = np .array (vectors , dtype = np .float32 )
198+ y = np .array (labels , dtype = np .int32 )
199+
200+ output_dir = Path (settings .model_dir )
201+ loop = asyncio .get_running_loop ()
202+ result = await loop .run_in_executor (
203+ None ,
204+ lambda : TrainingOrchestrator (
205+ output_dir = output_dir ,
206+ ).run (X , y ),
163207 )
164- X = np .concatenate ([
165- np .array (vectors , dtype = np .float32 ),
166- syn_X ,
167- ]) if vectors else syn_X
168- y = np .concatenate ([
169- np .array (labels , dtype = np .int32 ),
170- syn_y ,
171- ]) if labels else syn_y
208+
172209 logger .info (
173- "Retrain job %s: supplemented with %d synthetic samples " ,
210+ "Retrain job %s complete: passed_gates=%s " ,
174211 job_id ,
175- len ( syn_X ) ,
212+ result . passed_gates ,
176213 )
177- else :
178- X = np .array (vectors , dtype = np .float32 )
179- y = np .array (labels , dtype = np .int32 )
180-
181- output_dir = Path (settings .model_dir )
182- loop = asyncio .get_running_loop ()
183- result = await loop .run_in_executor (
184- None ,
185- lambda : TrainingOrchestrator (output_dir = output_dir ).run (X , y ),
186- )
187-
188- logger .info (
189- "Retrain job %s complete: passed_gates=%s" ,
190- job_id ,
191- result .passed_gates ,
192- )
193214
194- try :
195- from cli .main import _write_metadata
215+ try :
216+ from cli .main import _write_metadata
196217
197- metrics : dict [str , object ] = (
198- dataclasses .asdict (result .ensemble_metrics )
199- if result .ensemble_metrics else {}
200- )
201- await _write_metadata (
202- output_dir ,
203- len (X ),
204- metrics ,
205- result .mlflow_run_id ,
206- result .ae_metrics .get ("ae_threshold" ),
207- )
208- except Exception :
209- logger .exception (
210- "Retrain job %s: failed to write metadata" ,
211- job_id ,
212- )
218+ metrics : dict [str , object ] = (
219+ dataclasses .asdict (result .ensemble_metrics )
220+ if result .ensemble_metrics else {}
221+ )
222+ await _write_metadata (
223+ output_dir ,
224+ len (X ),
225+ metrics ,
226+ result .mlflow_run_id ,
227+ result .ae_metrics .get ("ae_threshold" ),
228+ )
229+ except Exception :
230+ logger .exception (
231+ "Retrain job %s: failed to write metadata" ,
232+ job_id ,
233+ )
213234
214235
215236def _fallback_synthetic (job_id : str ) -> None :
216237 """
217238 Run training with synthetic data only when no real
218239 events exist
219240 """
220- import subprocess
241+ global _synthetic_process # noqa: PLW0603
221242 import sys
222243
223- logger .info ("Retrain job %s: falling back to synthetic training" , job_id )
224- subprocess .Popen (
244+ if _synthetic_process is not None :
245+ if _synthetic_process .poll () is None :
246+ logger .info (
247+ "Retrain job %s: synthetic training already "
248+ "running (pid=%d)" ,
249+ job_id ,
250+ _synthetic_process .pid ,
251+ )
252+ return
253+ rc = _synthetic_process .returncode
254+ if rc != 0 :
255+ logger .warning (
256+ "Previous synthetic training exited with %d" ,
257+ rc ,
258+ )
259+
260+ logger .info (
261+ "Retrain job %s: falling back to synthetic training" ,
262+ job_id ,
263+ )
264+ _synthetic_process = subprocess .Popen (
225265 [
226266 sys .executable ,
227267 "-m" ,
0 commit comments