1515from leakpro .fl_utils .data_utils import GiaTabularExtension
1616from leakpro .utils .seed import seed_everything
1717
18- from tabular_metrics import evaluate_reconstruction
18+ from tabular_metrics import apply_matching , compute_metrics , nearest_neighbor_distance , tab_leak_accuracy , write_results_table
1919
2020from train import train_global_model
2121from model import TabularMLP
22- from tabular import get_tabular_loaders , load_tabular_config
22+ from tabular import get_tabular_loaders , load_tabular_config , denormalize_features
2323
2424
2525
@@ -149,6 +149,92 @@ def get_set_dataloaders(cfg: dict, saved_encoder_meta: dict | None) -> tuple[dic
149149 return loaders , loaders ["data_mean" ], loaders ["data_std" ]
150150
151151
def evaluate_inversion(
    attacker: InvertingGradients,
    orig_tensor: torch.Tensor,
    recon_tensor: torch.Tensor,
    encoder_meta: dict,
    train_loader: DataLoader,
    results_path: str,
    label_tensor: torch.Tensor | None,
    *,
    num_eps_scale: float = 0.319,
) -> None:
    """Centralized evaluation/logging for tabular inversion results.

    Logs nearest-neighbour distance statistics of the reconstruction against
    the target batch (and, when small enough, against the whole training
    set), computes per-row accuracy and aggregate metrics, and hands
    everything to ``write_results_table``.

    Args:
        attacker: Attack object. Currently unused by this function; kept so
            existing call sites keep working.
        orig_tensor: Original (target) rows.
        recon_tensor: Reconstructed rows; ``shape[1]`` is read as the
            feature count.
        encoder_meta: Metadata dict. The keys ``num_cols``, ``cat_cols``,
            ``cat_categories``, ``num_mean`` and ``num_std`` are read here;
            the dict is also forwarded to the metric helpers.
        train_loader: Loader over the training data. Only used when it
            exposes a ``dataset`` attribute.
        results_path: Destination forwarded to ``write_results_table``.
        label_tensor: Optional labels forwarded to ``write_results_table``.
        num_eps_scale: Per-feature tolerance factor for numerical columns.
            Used as-is on normalized data and multiplied by ``num_std`` when
            the data is de-normalized.
    """
    # Upper bound on rows * features for the full train-set comparison.
    # NOTE(review): presumably chosen to bound memory/time -- confirm.
    max_train_elements = 2.0e7

    num_cols = encoder_meta.get("num_cols", [])
    cat_cols = encoder_meta.get("cat_cols", [])
    cat_categories = encoder_meta.get("cat_categories", {}) or {}

    num_count = len(num_cols)
    num_eps = np.full(num_count, num_eps_scale, dtype=np.float32)

    nn_batch = nearest_neighbor_distance(recon_tensor, orig_tensor, encoder_meta)
    logger.info(
        "NN distance (target batch): min=%.4f mean=%.4f median=%.4f",
        nn_batch["min"],
        nn_batch["mean"],
        nn_batch["median"],
    )

    if hasattr(train_loader, "dataset"):
        train_count = len(train_loader.dataset)
        feature_count = recon_tensor.shape[1]

        # Only materialize the training set when it is small enough.
        train_batches = []
        if train_count * feature_count <= max_train_elements:
            train_batches = [batch[0] for batch in train_loader]

        if train_batches:
            train_tensor = torch.cat(train_batches, dim=0).detach().cpu()
            nn_train = nearest_neighbor_distance(recon_tensor, train_tensor, encoder_meta)
            logger.info(
                "NN distance (train set): min=%.4f mean=%.4f median=%.4f",
                nn_train["min"],
                nn_train["mean"],
                nn_train["median"],
            )
        else:
            # Either the set is too large, or the loader yielded no batches;
            # torch.cat raises on an empty list, so skip instead of crashing.
            logger.info(
                "Skipped NN distance over train set (rows=%d features=%d).",
                train_count,
                feature_count,
            )

    # Second return value is not needed here.
    # NOTE(review): the meaning of the trailing None argument is defined by
    # tab_leak_accuracy -- confirm against its signature.
    per_row_acc, _ = tab_leak_accuracy(
        orig_tensor,
        recon_tensor,
        num_cols,
        cat_cols,
        cat_categories,
        None,
    )

    # De-normalize before writing results.
    num_mean = encoder_meta.get("num_mean")
    num_std = encoder_meta.get("num_std")
    orig_raw = orig_tensor
    recon_raw = recon_tensor
    num_eps_raw = num_eps
    if num_count > 0 and num_mean is not None and num_std is not None:
        orig_raw = denormalize_features(orig_tensor, num_mean, num_std, num_count)
        recon_raw = denormalize_features(recon_tensor, num_mean, num_std, num_count)
        num_std_np = np.asarray(num_std).reshape(-1)
        # Scale the tolerance per feature only when a std is available for
        # every numerical column; otherwise keep the unscaled tolerance.
        if num_std_np.size >= num_count:
            num_eps_raw = num_eps_scale * num_std_np[:num_count]

    # Compute RMSE and MAE across the reconstruction batch on the
    # de-normalized data.
    metrics = compute_metrics(orig_raw, recon_raw, encoder_meta)

    write_results_table(
        results_path,
        orig_raw,
        recon_raw,
        num_cols,
        cat_cols,
        cat_categories,
        num_eps_raw,
        label_tensor,
        per_row_acc,
        nn_batch,
        metrics,
    )
237+
152238def run_attack (model : TabularMLP , client_loader : DataLoader , data_mean : torch .Tensor , data_std : torch .Tensor , encoder_meta : dict , train_loader : DataLoader ) -> None :
153239
154240 # 1. Compute Public/Prior Stats
@@ -175,11 +261,15 @@ def run_attack(model: TabularMLP, client_loader: DataLoader, data_mean: torch.Te
175261
176262 # Configure InvertingGradients
177263 attack_config = InvertingConfig (
178- attack_lr = 0.01 ,
179- at_iterations = 1000 ,
180264 tv_reg = 0.01 , # This will be ignored by generic_attack_loop because is_tabular will be true
265+ attack_lr = 0.03 ,
266+ at_iterations = 10000 ,
267+ #optimizer: object = lambda : MetaSGD(),
181268 criterion = criterion ,
182- data_extension = data_extension
269+ data_extension = data_extension ,
270+ epochs = 1 ,
271+ #median_pooling = False,
272+ #top10norms = False
183273 )
184274
185275 attacker = InvertingGradients (
@@ -215,59 +305,12 @@ def run_attack(model: TabularMLP, client_loader: DataLoader, data_mean: torch.Te
215305 if result :
216306 last_result = result
217307
218- if last_result is None :
219- logger .warning ("No result produced." )
220- return
221-
222308 logger .info (f"Attack complete. Final Score: { float (last_score ):.4f} " )
223309
224- # Detailed logging
225- # InvertingGradients stores best_reconstruction as DataLoader (copied from reconstruction_loader)
226- # We need to extract the tensor
227- if attacker .best_reconstruction :
228- recon_tensor = torch .cat ([batch [0 ] for batch in attacker .best_reconstruction ], dim = 0 ).detach ().cpu ()
229- else :
230- recon_tensor = torch .zeros_like (attacker .original .cpu ())
310+ return attacker , last_result , data_mean , data_std , encoder_meta , train_loader
231311
232- orig_tensor = attacker .original .cpu ()
233-
234- # De-standardize if mean/std available
235- if data_mean is not None and data_std is not None :
236- data_mean = data_mean .cpu ()
237- data_std = data_std .cpu ()
238- # Safe inverse transform
239- orig_raw = orig_tensor * data_std + data_mean
240- recon_raw = recon_tensor * data_std + data_mean
241- else :
242- orig_raw = orig_tensor
243- recon_raw = recon_tensor
244312
245- # Show metrics
246- if last_result :
247- # last_result is GIAResults object
248- # It contains rmse_score, mae_score etc for tabular
249- logger .info (f"Final Metrics: RMSE={ last_result .rmse_score :.4f} , MAE={ last_result .mae_score :.4f} " )
250-
251- # We can also re-calculate our custom detailed metrics if desired
252- full_metrics = evaluate_reconstruction (
253- attacker .original .to (attacker .original .device ),
254- recon_tensor .to (attacker .original .device ),
255- encoder_meta ,
256- return_per_feature = True
257- )
258- agg = full_metrics ['aggregate' ]
259- logger .info (f"Detailed Metrics: Numerical Score={ agg ['numerical_score' ]:.4f} , Categorical Score={ agg ['categorical_score' ]:.4f} " )
260- print (f"FINAL_METRICS: Numerical={ agg ['numerical_score' ]:.4f} Categorical={ agg ['categorical_score' ]:.4f} " )
261-
262- # Show best row
263- errors = torch .sum (torch .abs (orig_tensor - recon_tensor ), dim = 1 )
264- idx_best = int (errors .argmin ().item ())
265- logger .info ("Best row (idx=%d) original: %s" , idx_best , orig_raw [idx_best ].numpy ())
266- logger .info ("Best row (idx=%d) reconstructed: %s" , idx_best , recon_raw [idx_best ].numpy ())
267-
268-
269-
270- def main (protocol : str = "fedsgd" ) -> None :
313+ def main (protocol : str = "fedsgd" , results_path : str = "results.txt" ) -> None :
271314 seed_everything (42 )
272315 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
273316 base_cfg , ckpt_path = load_config ()
@@ -298,11 +341,38 @@ def main(protocol: str = "fedsgd") -> None:
298341 "ckpt" if ckpt_std is not None else "loader" ,
299342 )
300343
301- run_attack (model , client_loader , data_mean , data_std , saved_encoder_meta , loaders ["train_loader" ])
344+ # run the attack on the model with the client data
345+ attacker , last_result , data_mean , data_std , encoder_meta , train_loader = run_attack (model , client_loader , data_mean , data_std , saved_encoder_meta , loaders ["train_loader" ])
346+
347+ # metrics from attack
348+ # last_result is GIAResults object. it contains rmse_score, mae_score, etc for tabular
349+ if last_result is not None :
350+ logger .info ("Final Metrics: RMSE=%.4f, MAE=%.4f" , last_result .rmse_score , last_result .mae_score )
351+
352+ # Extract tensors for evaluation
353+ if attacker .best_reconstruction :
354+ recon_tensor = torch .cat ([batch [0 ] for batch in attacker .best_reconstruction ], dim = 0 ).detach ().cpu ()
355+ else :
356+ recon_tensor = torch .zeros_like (attacker .original .cpu ())
357+ orig_tensor = attacker .original .detach ().cpu ()
358+
359+ labels = getattr (attacker , "reconstruction_labels" , None )
360+ if labels is None :
361+ label_tensor = None
362+ elif isinstance (labels , list ):
363+ label_tensor = torch .stack (labels ).view (- 1 ).detach ().cpu ()
364+ else :
365+ label_tensor = torch .as_tensor (labels ).view (- 1 ).detach ().cpu ()
366+
367+ # apply hungarian matching
368+ orig_tensor , recon_tensor , label_tensor = apply_matching (orig_tensor , recon_tensor , label_tensor )
369+ # run evaluation of the attack
370+ evaluate_inversion (attacker , orig_tensor , recon_tensor , encoder_meta , train_loader , results_path , label_tensor )
302371
303372
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the demo.
    cli = argparse.ArgumentParser(description="Tabular GIA demo")
    cli.add_argument(
        "--protocol",
        choices=["fedavg", "fedsgd"],
        default="fedsgd",
        help="Full client split (fedavg) or single mini-batch (fedsgd)",
    )
    cli.add_argument(
        "--results-path",
        default="results.txt",
        help="Results path",
    )
    args = cli.parse_args()
    main(protocol=args.protocol, results_path=args.results_path)
0 commit comments