22
33import networkx as nx
44import numpy as np
5- import pandas as pd
6- import pytest
75import pytorch_lightning as pl
86import torch
9- from sklearn .metrics import precision_score , recall_score , roc_auc_score
10- from utils import gen_data_nonlinear , load_adult
11- from xgboost import XGBClassifier
7+ from utils import gen_data_nonlinear
128
139from decaf import DECAF , DataModule
1410
@@ -72,7 +68,7 @@ def test_sanity_generate() -> None:
7268 dummy_dm .dims [0 ],
7369 dag_seed = seed ,
7470 )
75- trainer = pl .Trainer (max_epochs = 2 , logger = False )
71+ trainer = pl .Trainer (max_epochs = 100 , logger = True )
7672
7773 trainer .fit (model , dummy_dm )
7874
@@ -84,53 +80,3 @@ def test_sanity_generate() -> None:
8480 .numpy ()
8581 )
8682 assert synth_data .shape [0 ] == 10
87-
88-
89- @pytest .mark .parametrize ("X,y" , [load_adult ()])
90- @pytest .mark .slow
91- def test_run_experiments (X : pd .DataFrame , y : pd .DataFrame ) -> None :
92- baseline_clf = XGBClassifier ().fit (X , y )
93- y_pred = baseline_clf .predict (X )
94-
95- print (
96- "baseline scores" ,
97- precision_score (y , y_pred ),
98- recall_score (y , y_pred ),
99- roc_auc_score (y , y_pred ),
100- )
101-
102- dm = DataModule (X )
103-
104- model = DECAF (
105- dm .dims [0 ],
106- use_mask = True ,
107- grad_dag_loss = False ,
108- lambda_privacy = 0 ,
109- lambda_gp = 10 ,
110- weight_decay = 1e-2 ,
111- l1_g = 0 ,
112- p_gen = - 1 ,
113- batch_size = 100 ,
114- )
115- trainer = pl .Trainer (max_epochs = 10 , logger = False )
116- trainer .fit (model , dm )
117-
118- X_synth = (
119- model .gen_synthetic (
120- dm .dataset .x ,
121- gen_order = model .get_gen_order (),
122- )
123- .detach ()
124- .numpy ()
125- )
126- y_synth = baseline_clf .predict (X_synth )
127-
128- synth_clf = XGBClassifier ().fit (X_synth , y_synth )
129- y_pred = synth_clf .predict (X_synth )
130-
131- print (
132- "synth scores" ,
133- precision_score (y_synth , y_pred ),
134- recall_score (y_synth , y_pred ),
135- roc_auc_score (y_synth , y_pred ),
136- )
0 commit comments