Skip to content

Commit 249670c

Browse files
authored
Merge pull request #2 from Maxengw/ivo/traintabular_modifiedgiaattack
Ivo/traintabular modifiedgiaattack
2 parents 3065171 + a54733a commit 249670c

File tree

12 files changed

+618
-193
lines changed

12 files changed

+618
-193
lines changed
Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,18 @@
11
target: income
2-
missing_values: "?"
2+
missing_values: "?"
3+
numerical_columns:
4+
- age
5+
- fnlwgt
6+
- education-num
7+
- capital-gain
8+
- capital-loss
9+
- hours-per-week
10+
categorical_columns:
11+
- workclass
12+
- education
13+
- marital-status
14+
- occupation
15+
- relationship
16+
- race
17+
- sex
18+
- native-country
Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,12 @@
11
target: y
2-
missing_values: unkown
2+
missing_values: unknown
3+
categorical_columns:
4+
- job
5+
- marital
6+
- education
7+
- default
8+
- housing
9+
- loan
10+
- contact
11+
- month
12+
- poutcome
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
target: 54
22
no_header: true
3-
missing_values: ""
3+
missing_values: ""
4+
categorical_columns: []
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
target: class
22
missing_values: ""
33
has_test_split: data/multiclass/statlog+shuttle/statlog+shuttle.test.csv
4+
categorical_columns: []
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
target: median_house_value
2-
missing_values: ""
2+
missing_values: ""
3+
categorical_columns: []
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
target: 0
22
no_header: true
3-
missing_values: ""
3+
missing_values: ""
4+
categorical_columns: []

examples/gia/GIA_base_running_tabular/main.py

Lines changed: 126 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from leakpro.fl_utils.data_utils import GiaTabularExtension
1616
from leakpro.utils.seed import seed_everything
1717

18-
from tabular_metrics import evaluate_reconstruction
18+
from tabular_metrics import apply_matching, compute_metrics, nearest_neighbor_distance, tab_leak_accuracy, write_results_table
1919

2020
from train import train_global_model
2121
from model import TabularMLP
22-
from tabular import get_tabular_loaders, load_tabular_config
22+
from tabular import get_tabular_loaders, load_tabular_config, denormalize_features
2323

2424

2525

@@ -149,6 +149,92 @@ def get_set_dataloaders(cfg: dict, saved_encoder_meta: dict | None) -> tuple[dic
149149
return loaders, loaders["data_mean"], loaders["data_std"]
150150

151151

152+
def evaluate_inversion(
    attacker: InvertingGradients,
    orig_tensor: torch.Tensor,
    recon_tensor: torch.Tensor,
    encoder_meta: dict,
    train_loader: DataLoader,
    results_path: str,
    label_tensor: torch.Tensor | None,
) -> None:
    """Centralized evaluation/logging for tabular inversion results.

    Logs nearest-neighbour distances of the reconstruction against both the
    target batch and (when small enough) the full training set, computes
    per-row accuracy via ``tab_leak_accuracy``, de-normalizes numerical
    features back to raw units, computes batch metrics, and writes a results
    table to ``results_path``.

    Args:
        attacker: The attack object that produced the reconstruction.
            NOTE(review): currently unused in this body — kept for interface
            stability or future use; confirm before removing.
        orig_tensor: Original (target) client batch, encoded/normalized.
        recon_tensor: Reconstructed batch in the same encoded space.
        encoder_meta: Encoder metadata; this function reads the keys
            "num_cols", "cat_cols", "cat_categories", "num_mean", "num_std".
        train_loader: Loader over the global training data (used for the
            train-set nearest-neighbour baseline).
        results_path: File path the results table is written to.
        label_tensor: Optional per-row labels passed through to the table.
    """
    # Column metadata produced by the tabular encoder; default to empty so a
    # purely numerical or purely categorical dataset still works.
    num_cols = encoder_meta.get("num_cols", [])
    cat_cols = encoder_meta.get("cat_cols", [])
    cat_categories = encoder_meta.get("cat_categories", {}) or {}

    num_count = len(num_cols)
    # Per-numerical-feature match tolerance in *normalized* units.
    # 0.319 — presumably the tolerance used in the TabLeak evaluation
    # protocol; TODO confirm the source of this constant.
    num_eps = np.full(num_count, 0.319, dtype=np.float32)

    # Distance of each reconstructed row to its nearest row in the target
    # batch itself (sanity check / best-case leakage signal).
    nn_batch = nearest_neighbor_distance(recon_tensor, orig_tensor, encoder_meta)
    logger.info(
        "NN distance (target batch): min=%.4f mean=%.4f median=%.4f",
        nn_batch["min"],
        nn_batch["mean"],
        nn_batch["median"],
    )

    if hasattr(train_loader, "dataset"):
        train_count = len(train_loader.dataset)
        # Guard: the train-set baseline materializes the whole training set
        # in memory; skip it above ~2e7 cells (rows * features).
        if train_count * recon_tensor.shape[1] <= 2.0e7:
            train_tensor = torch.cat([batch[0] for batch in train_loader], dim=0).detach().cpu()
            nn_train = nearest_neighbor_distance(recon_tensor, train_tensor, encoder_meta)
            logger.info(
                "NN distance (train set): min=%.4f mean=%.4f median=%.4f",
                nn_train["min"],
                nn_train["mean"],
                nn_train["median"],
            )
        else:
            logger.info(
                "Skipped NN distance over train set (rows=%d features=%d).",
                train_count,
                recon_tensor.shape[1],
            )

    # Per-row reconstruction accuracy; second return value (aggregate or
    # per-feature breakdown — confirm against tabular_metrics) is unused here.
    per_row_acc, _ = tab_leak_accuracy(
        orig_tensor,
        recon_tensor,
        num_cols,
        cat_cols,
        cat_categories,
        None,
    )

    # de-normalize before writing results
    num_mean = encoder_meta.get("num_mean")
    num_std = encoder_meta.get("num_std")
    if num_count > 0 and num_mean is not None and num_std is not None:
        orig_raw = denormalize_features(orig_tensor, num_mean, num_std, num_count)
        recon_raw = denormalize_features(recon_tensor, num_mean, num_std, num_count)
        num_std_np = np.asarray(num_std).reshape(-1)
        if num_std_np.size >= num_count:
            # Scale the normalized tolerance back to raw feature units.
            num_eps_raw = 0.319 * num_std_np[:num_count]
        else:
            # std vector shorter than expected — fall back to the
            # normalized-space tolerance rather than crashing.
            num_eps_raw = num_eps
    else:
        # No normalization stats available: tensors are already "raw".
        orig_raw = orig_tensor
        recon_raw = recon_tensor
        num_eps_raw = num_eps

    # compute rmse and mae across the reconstruction batch on the de-normalized data
    metrics = compute_metrics(orig_raw, recon_raw, encoder_meta)

    write_results_table(
        results_path,
        orig_raw,
        recon_raw,
        num_cols,
        cat_cols,
        cat_categories,
        num_eps_raw,
        label_tensor,
        per_row_acc,
        nn_batch,
        metrics,
    )
236+
237+
152238
def run_attack(model: TabularMLP, client_loader: DataLoader, data_mean: torch.Tensor, data_std: torch.Tensor, encoder_meta: dict, train_loader: DataLoader) -> None:
153239

154240
# 1. Compute Public/Prior Stats
@@ -175,11 +261,15 @@ def run_attack(model: TabularMLP, client_loader: DataLoader, data_mean: torch.Te
175261

176262
# Configure InvertingGradients
177263
attack_config = InvertingConfig(
178-
attack_lr=0.01,
179-
at_iterations=1000,
180264
tv_reg=0.01, # This will be ignored by generic_attack_loop because is_tabular will be true
265+
attack_lr=0.03,
266+
at_iterations=10000,
267+
#optimizer: object = lambda : MetaSGD(),
181268
criterion=criterion,
182-
data_extension=data_extension
269+
data_extension=data_extension,
270+
epochs=1,
271+
#median_pooling = False,
272+
#top10norms = False
183273
)
184274

185275
attacker = InvertingGradients(
@@ -215,59 +305,12 @@ def run_attack(model: TabularMLP, client_loader: DataLoader, data_mean: torch.Te
215305
if result:
216306
last_result = result
217307

218-
if last_result is None:
219-
logger.warning("No result produced.")
220-
return
221-
222308
logger.info(f"Attack complete. Final Score: {float(last_score):.4f}")
223309

224-
# Detailed logging
225-
# InvertingGradients stores best_reconstruction as DataLoader (copied from reconstruction_loader)
226-
# We need to extract the tensor
227-
if attacker.best_reconstruction:
228-
recon_tensor = torch.cat([batch[0] for batch in attacker.best_reconstruction], dim=0).detach().cpu()
229-
else:
230-
recon_tensor = torch.zeros_like(attacker.original.cpu())
310+
return attacker, last_result, data_mean, data_std, encoder_meta, train_loader
231311

232-
orig_tensor = attacker.original.cpu()
233-
234-
# De-standardize if mean/std available
235-
if data_mean is not None and data_std is not None:
236-
data_mean = data_mean.cpu()
237-
data_std = data_std.cpu()
238-
# Safe inverse transform
239-
orig_raw = orig_tensor * data_std + data_mean
240-
recon_raw = recon_tensor * data_std + data_mean
241-
else:
242-
orig_raw = orig_tensor
243-
recon_raw = recon_tensor
244312

245-
# Show metrics
246-
if last_result:
247-
# last_result is GIAResults object
248-
# It contains rmse_score, mae_score etc for tabular
249-
logger.info(f"Final Metrics: RMSE={last_result.rmse_score:.4f}, MAE={last_result.mae_score:.4f}")
250-
251-
# We can also re-calculate our custom detailed metrics if desired
252-
full_metrics = evaluate_reconstruction(
253-
attacker.original.to(attacker.original.device),
254-
recon_tensor.to(attacker.original.device),
255-
encoder_meta,
256-
return_per_feature=True
257-
)
258-
agg = full_metrics['aggregate']
259-
logger.info(f"Detailed Metrics: Numerical Score={agg['numerical_score']:.4f}, Categorical Score={agg['categorical_score']:.4f}")
260-
print(f"FINAL_METRICS: Numerical={agg['numerical_score']:.4f} Categorical={agg['categorical_score']:.4f}")
261-
262-
# Show best row
263-
errors = torch.sum(torch.abs(orig_tensor - recon_tensor), dim=1)
264-
idx_best = int(errors.argmin().item())
265-
logger.info("Best row (idx=%d) original: %s", idx_best, orig_raw[idx_best].numpy())
266-
logger.info("Best row (idx=%d) reconstructed: %s", idx_best, recon_raw[idx_best].numpy())
267-
268-
269-
270-
def main(protocol: str = "fedsgd") -> None:
313+
def main(protocol: str = "fedsgd", results_path: str = "results.txt") -> None:
271314
seed_everything(42)
272315
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
273316
base_cfg, ckpt_path = load_config()
@@ -298,11 +341,38 @@ def main(protocol: str = "fedsgd") -> None:
298341
"ckpt" if ckpt_std is not None else "loader",
299342
)
300343

301-
run_attack(model, client_loader, data_mean, data_std, saved_encoder_meta, loaders["train_loader"])
344+
# run the attack on the model with the client data
345+
attacker, last_result, data_mean, data_std, encoder_meta, train_loader = run_attack(model, client_loader, data_mean, data_std, saved_encoder_meta, loaders["train_loader"])
346+
347+
# metrics from attack
348+
# last_result is GIAResults object. it contains rmse_score, mae_score, etc for tabular
349+
if last_result is not None:
350+
logger.info("Final Metrics: RMSE=%.4f, MAE=%.4f", last_result.rmse_score, last_result.mae_score)
351+
352+
# Extract tensors for evaluation
353+
if attacker.best_reconstruction:
354+
recon_tensor = torch.cat([batch[0] for batch in attacker.best_reconstruction], dim=0).detach().cpu()
355+
else:
356+
recon_tensor = torch.zeros_like(attacker.original.cpu())
357+
orig_tensor = attacker.original.detach().cpu()
358+
359+
labels = getattr(attacker, "reconstruction_labels", None)
360+
if labels is None:
361+
label_tensor = None
362+
elif isinstance(labels, list):
363+
label_tensor = torch.stack(labels).view(-1).detach().cpu()
364+
else:
365+
label_tensor = torch.as_tensor(labels).view(-1).detach().cpu()
366+
367+
# apply hungarian matching
368+
orig_tensor, recon_tensor, label_tensor = apply_matching(orig_tensor, recon_tensor, label_tensor)
369+
# run evaluation of the attack
370+
evaluate_inversion(attacker, orig_tensor, recon_tensor, encoder_meta, train_loader, results_path, label_tensor)
302371

303372

304373
if __name__ == "__main__":
305374
parser = argparse.ArgumentParser(description="Tabular GIA demo")
306375
parser.add_argument("--protocol", choices=["fedavg", "fedsgd"], default="fedsgd", help="Full client split (fedavg) or single mini-batch (fedsgd)")
376+
parser.add_argument("--results-path", default="results.txt", help="Results path")
307377
parsed = parser.parse_args()
308-
main(protocol=parsed.protocol)
378+
main(protocol=parsed.protocol, results_path=parsed.results_path)

examples/gia/GIA_base_running_tabular/model.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import nn
33

44

5-
class TabularMLP2(nn.Module):
5+
class TabularMLP1(nn.Module):
66
"""Simple two-layer MLP."""
77

88
def __init__(self, d_in: int, d_hidden: int = 64, num_classes: int = 2) -> None:
@@ -16,7 +16,7 @@ def __init__(self, d_in: int, d_hidden: int = 64, num_classes: int = 2) -> None:
1616
def forward(self, x: torch.Tensor) -> torch.Tensor:
1717
return self.net(x)
1818

19-
class TabularMLP(nn.Module):
19+
class TabularMLP2(nn.Module):
2020
"""Deeper MLP with normalization and dropout."""
2121

2222
def __init__(self, d_in: int, d_hidden: int = 256, num_classes: int = 2, dropout: float = 0.1) -> None:
@@ -33,4 +33,40 @@ def __init__(self, d_in: int, d_hidden: int = 256, num_classes: int = 2, dropout
3333
)
3434

3535
def forward(self, x: torch.Tensor) -> torch.Tensor:
36-
return self.net(x)
36+
return self.net(x)
37+
38+
39+
class ResidualMLP(nn.Module):
    """Pre-norm residual block for the tabular MLP.

    Applies LayerNorm -> GELU -> Dropout -> Linear and adds the result back
    onto the input (identity skip connection), so stacking blocks keeps the
    feature dimension unchanged.
    """

    def __init__(self, dim: int, dropout: float) -> None:
        super().__init__()
        # Keep the layer creation order identical to a plain Sequential
        # definition so parameter initialization is unaffected.
        layers = [
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: output = input + transformed(input).
        delta = self.net(x)
        return x + delta
53+
54+
55+
class TabularMLP(nn.Module):
    """Stronger MLP with residual blocks.

    Projects the input to ``d_hidden``, runs ``n_blocks`` ResidualMLP blocks,
    then classifies through a LayerNorm/GELU/Dropout/Linear head.
    """

    def __init__(self, d_in: int, d_hidden: int = 512, num_classes: int = 2, dropout: float = 0.3, n_blocks: int = 3) -> None:
        super().__init__()
        # Input projection into the hidden width used by every block.
        self.in_proj = nn.Linear(d_in, d_hidden)
        # Residual trunk; each block preserves the d_hidden dimension.
        self.blocks = nn.Sequential(
            *(ResidualMLP(d_hidden, dropout) for _ in range(n_blocks))
        )
        # Classification head mapping hidden features to class logits.
        self.head = nn.Sequential(
            nn.LayerNorm(d_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.in_proj(x)
        hidden = self.blocks(hidden)
        return self.head(hidden)

0 commit comments

Comments
 (0)