33import time
44from sqlalchemy import create_engine
55from sqlalchemy .orm import sessionmaker , Session
6-
6+ from fastapi . encoders import jsonable_encoder
77from whitebox import crud , entities
88from whitebox .analytics .drift .pipelines import (
99 run_data_drift_pipeline ,
1919from whitebox .cron_tasks .shared import (
2020 get_all_models ,
2121 get_model_dataset_rows_df ,
22- get_model_inference_rows_df ,
22+ get_unused_model_inference_rows ,
23+ group_inference_rows_by_timestamp ,
24+ seperate_inference_rows ,
25+ set_inference_rows_to_used ,
26+ get_latest_drift_metrics_report ,
27+ round_timestamp ,
28+ get_used_inference_for_reusage ,
2329)
2430from whitebox .schemas .model import Model , ModelType
2531from whitebox .schemas .modelIntegrityMetric import ModelIntegrityMetricCreate
3440
3541
3642async def run_calculate_drifting_metrics_pipeline (
37- model : Model , inference_processed_df : pd .DataFrame
43+ model : Model , inference_processed_df : pd .DataFrame , timestamp : datetime
3844):
3945 """
4046 Run the pipeline to calculate the drifting metrics
@@ -67,18 +73,29 @@ async def run_calculate_drifting_metrics_pipeline(
6773 )
6874
6975 new_drifting_metric = entities .DriftingMetric (
70- timestamp = str (datetime . utcnow () ),
76+ timestamp = str (timestamp ),
7177 model_id = model .id ,
7278 concept_drift_summary = concept_drift_report ,
7379 data_drift_summary = data_drift_report ,
7480 )
7581
76- crud .drifting_metrics .create (db , obj_in = new_drifting_metric )
82+ existing_report = crud .drifting_metrics .get_first_by_filter (
83+ db = db , model_id = model .id , timestamp = timestamp
84+ )
85+ if existing_report :
86+ crud .drifting_metrics .update (
87+ db = db , db_obj = existing_report , obj_in = jsonable_encoder (new_drifting_metric )
88+ )
89+ else :
90+ crud .drifting_metrics .create (db , obj_in = new_drifting_metric )
7791 logger .info ("Drifting metrics calculated!" )
7892
7993
8094async def run_calculate_performance_metrics_pipeline (
81- model : Model , inference_processed_df : pd .DataFrame , actual_df : pd .DataFrame
95+ model : Model ,
96+ inference_processed_df : pd .DataFrame ,
97+ actual_df : pd .DataFrame ,
98+ timestamp : datetime ,
8299):
83100 """
84101 Run the pipeline to calculate the performance metrics
@@ -121,11 +138,21 @@ async def run_calculate_performance_metrics_pipeline(
121138
122139 new_performance_metric = entities .BinaryClassificationMetrics (
123140 model_id = model .id ,
124- timestamp = str (datetime . utcnow () ),
141+ timestamp = str (timestamp ),
125142 ** dict (binary_classification_metrics_report ),
126143 )
127144
128- crud .binary_classification_metrics .create (db , obj_in = new_performance_metric )
145+ existing_report = crud .binary_classification_metrics .get_first_by_filter (
146+ db = db , model_id = model .id , timestamp = timestamp
147+ )
148+ if existing_report :
149+ crud .binary_classification_metrics .update (
150+ db = db ,
151+ db_obj = existing_report ,
152+ obj_in = jsonable_encoder (new_performance_metric ),
153+ )
154+ else :
155+ crud .binary_classification_metrics .create (db , obj_in = new_performance_metric )
129156
130157 elif model .type == ModelType .multi_class :
131158 multiclass_classification_metrics_report = (
@@ -136,11 +163,21 @@ async def run_calculate_performance_metrics_pipeline(
136163
137164 new_performance_metric = entities .MultiClassificationMetrics (
138165 model_id = model .id ,
139- timestamp = str (datetime . utcnow () ),
166+ timestamp = str (timestamp ),
140167 ** dict (multiclass_classification_metrics_report ),
141168 )
142169
143- crud .multi_classification_metrics .create (db , obj_in = new_performance_metric )
170+ existing_report = crud .multi_classification_metrics .get_first_by_filter (
171+ db = db , model_id = model .id , timestamp = timestamp
172+ )
173+ if existing_report :
174+ crud .multi_classification_metrics .update (
175+ db = db ,
176+ db_obj = existing_report ,
177+ obj_in = jsonable_encoder (new_performance_metric ),
178+ )
179+ else :
180+ crud .multi_classification_metrics .create (db , obj_in = new_performance_metric )
144181
145182 elif model .type == ModelType .regression :
146183 regression_metrics_report = create_regression_evaluation_metrics_pipeline (
@@ -149,17 +186,27 @@ async def run_calculate_performance_metrics_pipeline(
149186
150187 new_performance_metric = entities .RegressionMetrics (
151188 model_id = model .id ,
152- timestamp = str (datetime . utcnow () ),
189+ timestamp = str (timestamp ),
153190 ** dict (regression_metrics_report ),
154191 )
155192
156- crud .regression_metrics .create (db , obj_in = new_performance_metric )
193+ existing_report = crud .regression_metrics .get_first_by_filter (
194+ db = db , model_id = model .id , timestamp = timestamp
195+ )
196+ if existing_report :
197+ crud .regression_metrics .update (
198+ db = db ,
199+ db_obj = existing_report ,
200+ obj_in = jsonable_encoder (new_performance_metric ),
201+ )
202+ else :
203+ crud .regression_metrics .create (db , obj_in = new_performance_metric )
157204
158205 logger .info ("Performance metrics calculated!" )
159206
160207
161208async def run_calculate_feature_metrics_pipeline (
162- model : Model , inference_processed_df : pd .DataFrame
209+ model : Model , inference_processed_df : pd .DataFrame , timestamp : datetime
163210):
164211 """
165212 Run the pipeline to calculate the feature metrics
@@ -172,11 +219,22 @@ async def run_calculate_feature_metrics_pipeline(
172219 if feature_metrics_report :
173220 new_feature_metric = ModelIntegrityMetricCreate (
174221 model_id = model .id ,
175- timestamp = str (datetime . utcnow () ),
222+ timestamp = str (timestamp ),
176223 feature_metrics = feature_metrics_report ,
177224 )
178225
179- crud .model_integrity_metrics .create (db , obj_in = new_feature_metric )
226+ existing_report = crud .model_integrity_metrics .get_first_by_filter (
227+ db = db , model_id = model .id , timestamp = timestamp
228+ )
229+ if existing_report :
230+ crud .model_integrity_metrics .update (
231+ db = db ,
232+ db_obj = existing_report ,
233+ obj_in = jsonable_encoder (new_feature_metric ),
234+ )
235+ else :
236+ crud .model_integrity_metrics .create (db , obj_in = new_feature_metric )
237+
180238 logger .info ("Feature metrics calculated!" )
181239
182240
@@ -190,24 +248,72 @@ async def run_calculate_metrics_pipeline():
190248 logger .info ("No models found! Skipping pipeline" )
191249 else :
192250 for model in models :
193- (
194- inference_processed_df ,
195- inference_nonprocessed_df ,
196- actual_df ,
197- ) = await get_model_inference_rows_df (db , model_id = model .id )
198- if inference_processed_df .empty :
251+ granularity = model .granularity
252+ granularity_amount = int (granularity [:- 1 ])
253+ granularity_type = granularity [- 1 ]
254+
255+ last_report = await get_latest_drift_metrics_report (db , model )
256+
257+ # Use the latest report's timestamp as the base for grouping. If no report has
258+ # been produced yet, fall back to "now" rounded down to the day, so the intervals
259+ # align to midnight (e.g. 12:00, 12:15, 12:30, 12:45, ... for a 15T granularity).
260+ last_report_time = (
261+ last_report .timestamp
262+ if last_report
263+ else round_timestamp (datetime .utcnow (), "1D" )
264+ )
265+
266+ unused_inference_rows_in_db = await get_unused_model_inference_rows (
267+ db , model_id = model .id
268+ )
269+
270+ if len (unused_inference_rows_in_db ) == 0 :
199271 logger .info (
200- f"No inferences found for model { model .id } ! Continuing with next model..."
272+ f"No new inferences found for model { model .id } ! Continuing with next model..."
201273 )
202274 continue
203275 logger .info (f"Executing Metrics pipeline for model { model .id } ..." )
204- await run_calculate_drifting_metrics_pipeline (model , inference_processed_df )
205276
206- await run_calculate_performance_metrics_pipeline (
207- model , inference_processed_df , actual_df
277+ used_inferences = get_used_inference_for_reusage (
278+ db ,
279+ model .id ,
280+ unused_inference_rows_in_db ,
281+ last_report_time ,
282+ granularity_amount ,
283+ granularity_type ,
284+ )
285+
286+ all_inferences = unused_inference_rows_in_db + used_inferences
287+
288+ grouped_inference_rows = await group_inference_rows_by_timestamp (
289+ all_inferences ,
290+ last_report_time ,
291+ granularity_amount ,
292+ granularity_type ,
208293 )
209294
210- await run_calculate_feature_metrics_pipeline (model , inference_processed_df )
295+ for group in grouped_inference_rows :
296+ for timestamp , inference_group in group .items ():
297+ inference_rows_ids = [x .id for x in inference_group ]
298+ (
299+ inference_processed_df ,
300+ inference_nonprocessed_df ,
301+ actual_df ,
302+ ) = await seperate_inference_rows (inference_group )
303+
304+ await run_calculate_drifting_metrics_pipeline (
305+ model , inference_processed_df , timestamp
306+ )
307+
308+ await run_calculate_performance_metrics_pipeline (
309+ model , inference_processed_df , actual_df , timestamp
310+ )
311+
312+ await run_calculate_feature_metrics_pipeline (
313+ model , inference_processed_df , timestamp
314+ )
315+
316+ await set_inference_rows_to_used (db , inference_rows_ids )
211317
212318 logger .info (f"Ended Metrics pipeline for model { model .id } ..." )
213319
0 commit comments