
Commit 52e40fe

GAN fixes (#3)
1 parent ba255b3 commit 52e40fe

5 files changed: +105 -102 lines changed

5 files changed

+105
-102
lines changed

decaf/DECAF.py

Lines changed: 94 additions & 90 deletions
@@ -1,10 +1,9 @@
 from collections import OrderedDict
-from typing import Any, Optional, Union
+from typing import Any, List, Optional, Union
 
 import networkx as nx
 import numpy as np
 import pytorch_lightning as pl
-import scipy.linalg as slin
 import torch
 import torch.nn as nn
 
@@ -13,13 +12,32 @@
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
+def get_nonlin(name: str) -> nn.Module:
+    if name == "none":
+        return nn.Identity()
+    elif name == "elu":
+        return nn.ELU()
+    elif name == "relu":
+        return nn.ReLU()
+    elif name == "leaky_relu":
+        return nn.LeakyReLU()
+    elif name == "selu":
+        return nn.SELU()
+    elif name == "tanh":
+        return nn.Tanh()
+    elif name == "sigmoid":
+        return nn.Sigmoid()
+    elif name == "softmax":
+        return nn.Softmax(dim=-1)
+    else:
+        raise ValueError(f"Unknown nonlinearity {name}")
+
+
 class TraceExpm(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, input: torch.Tensor) -> torch.Tensor:
-        # detach so we can cast to NumPy
-        E = slin.expm(input.detach().numpy())
-        f = np.trace(E)
-        E = torch.from_numpy(E).to(DEVICE)
+    def forward(ctx: Any, data: torch.Tensor) -> torch.Tensor:
+        E = torch.linalg.matrix_exp(data)
+        f = torch.trace(E)
         ctx.save_for_backward(E)
         return torch.as_tensor(f, dtype=input.dtype)
 
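Note: the matrix exponential no longer round-trips through SciPy/NumPy; torch.linalg.matrix_exp keeps it on the model's device and dtype. A minimal sketch of the quantity this supports, assuming dag_loss uses the usual NOTEARS-style h(W) = tr(exp(W * W)) - d form; the snippet is illustrative, not part of the commit:

import torch

W = torch.rand(4, 4)                  # candidate adjacency matrix
E = torch.linalg.matrix_exp(W * W)    # replaces scipy.linalg.expm, stays on W.device
h = torch.trace(E) - W.shape[0]       # NOTEARS-style acyclicity term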

@@ -32,49 +50,57 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
 
 trace_expm = TraceExpm.apply
 
-activation_layer = nn.ReLU(inplace=True)
-
 
 class Generator_causal(nn.Module):
     def __init__(
         self,
         z_dim: int,
         x_dim: int,
         h_dim: int,
-        use_mask: bool = False,
         f_scale: float = 0.1,
         dag_seed: list = [],
+        nonlin_out: Optional[List] = None,
     ) -> None:
         super().__init__()
 
+        if nonlin_out is not None:
+            out_dim = 0
+            for act, length in nonlin_out:
+                out_dim += length
+            if out_dim != x_dim:
+                raise RuntimeError("Invalid nonlin_out")
+
         self.x_dim = x_dim
+        self.nonlin_out = nonlin_out
 
         def block(in_feat: int, out_feat: int, normalize: bool = False) -> list:
             layers = [nn.Linear(in_feat, out_feat)]
             if normalize:
                 layers.append(nn.BatchNorm1d(out_feat, 0.8))
-            layers.append(activation_layer)
+            layers.append(nn.ReLU(inplace=True))
            return layers
 
-        self.shared = nn.Sequential(*block(h_dim, h_dim), *block(h_dim, h_dim))
-
-        if use_mask:
+        self.shared = nn.Sequential(*block(h_dim, h_dim), *block(h_dim, h_dim)).to(
+            DEVICE
+        )
 
-            if len(dag_seed) > 0:
-                M_init = torch.rand(x_dim, x_dim) * 0.0
-                M_init[torch.eye(x_dim, dtype=bool)] = 0
-                M_init = torch.rand(x_dim, x_dim) * 0.0
-                for pair in dag_seed:
-                    M_init[pair[0], pair[1]] = 1
+        if len(dag_seed) > 0:
+            M_init = torch.rand(x_dim, x_dim) * 0.0
+            M_init[torch.eye(x_dim, dtype=bool)] = 0
+            M_init = torch.rand(x_dim, x_dim) * 0.0
+            for pair in dag_seed:
+                M_init[pair[0], pair[1]] = 1
 
-                self.M = torch.nn.parameter.Parameter(M_init, requires_grad=False)
-                print("Initialised adjacency matrix as parsed:\n", self.M)
-            else:
-                M_init = torch.rand(x_dim, x_dim) * 0.2
-                M_init[torch.eye(x_dim, dtype=bool)] = 0
-                self.M = torch.nn.parameter.Parameter(M_init)
+            M_init = M_init.to(DEVICE)
+            self.M = torch.nn.parameter.Parameter(M_init, requires_grad=False).to(
+                DEVICE
+            )
         else:
-            self.M = torch.ones(x_dim, x_dim)
+            M_init = torch.rand(x_dim, x_dim) * 0.2
+            M_init[torch.eye(x_dim, dtype=bool)] = 0
+            M_init = M_init.to(DEVICE)
+            self.M = torch.nn.parameter.Parameter(M_init).to(DEVICE)
+
         self.fc_i = nn.ModuleList(
             [nn.Linear(x_dim + 1, h_dim) for i in range(self.x_dim)]
         )
@@ -111,13 +137,28 @@ def sequential(
             x_masked[:, i] = 0.0
             if i in biased_edges:
                 for j in biased_edges[i]:
-                    x_j = x_masked[:, j].detach().numpy()
-                    np.random.shuffle(x_j)
-                    x_masked[:, j] = torch.from_numpy(x_j)
-            out_i = activation_layer(
-                self.fc_i[i](torch.cat([x_masked, z[:, i].unsqueeze(1)], axis=1))
-            )
-            out[:, i] = nn.Sigmoid()(self.fc_f[i](self.shared(out_i))).squeeze()
+                    x_j = x_masked[:, j]
+                    perm = torch.randperm(len(x_j))
+                    x_masked[:, j] = x_j[perm]
+            out_i = self.fc_i[i](torch.cat([x_masked, z[:, i].unsqueeze(1)], axis=1))
+            out_i = nn.ReLU()(out_i)
+            out_i = self.shared(out_i)
+            out_i = self.fc_f[i](out_i).squeeze()
+            out[:, i] = out_i
+
+        if self.nonlin_out is not None:
+            split = 0
+            for act_name, step in self.nonlin_out:
+                activation = get_nonlin(act_name)
+                out[..., split : split + step] = activation(
+                    out[..., split : split + step]
+                )
+
+                split += step
+
+            if split != out.shape[-1]:
+                raise ValueError("Invalid activations")
+
         return out
 
 
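Note: nonlin_out is a list of (activation_name, width) pairs whose widths must sum to x_dim; after the autoregressive pass, sequential applies each activation to its own block of output columns via get_nonlin. A small standalone sketch of that slicing, with a made-up layout of three raw columns followed by a 4-way one-hot block:

import torch
from decaf.DECAF import get_nonlin

nonlin_out = [("none", 3), ("softmax", 4)]   # widths must sum to x_dim (7 here)
out = torch.randn(8, 7)                      # stand-in for the generator output
split = 0
for act_name, step in nonlin_out:
    out[..., split : split + step] = get_nonlin(act_name)(out[..., split : split + step])
    split += step
assert split == out.shape[-1]                # a mismatch raises ValueError in sequential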

@@ -127,9 +168,9 @@ def __init__(self, x_dim: int, h_dim: int) -> None:
 
         self.model = nn.Sequential(
             nn.Linear(x_dim, h_dim),
-            activation_layer,
+            nn.ReLU(),
             nn.Linear(h_dim, h_dim),
-            activation_layer,
+            nn.ReLU(),
             nn.Linear(h_dim, 1),
         )
 
@@ -153,16 +194,14 @@ def __init__(
         batch_size: int = 32,
         lambda_gp: float = 10,
         lambda_privacy: float = 1,
-        d_updates: int = 5,
         eps: float = 1e-8,
         alpha: float = 1,
         rho: float = 1,
         weight_decay: float = 1e-2,
         grad_dag_loss: bool = False,
         l1_g: float = 0,
         l1_W: float = 1,
-        p_gen: float = -1,
-        use_mask: bool = False,
+        nonlin_out: Optional[List] = None,
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -183,8 +222,8 @@ def __init__(
             z_dim=self.z_dim,
             x_dim=self.x_dim,
             h_dim=h_dim,
-            use_mask=use_mask,
             dag_seed=dag_seed,
+            nonlin_out=nonlin_out,
         ).to(DEVICE)
         self.discriminator = Discriminator(x_dim=self.x_dim, h_dim=h_dim).to(DEVICE)
 
@@ -261,21 +300,7 @@ def privacy_loss(
         )
 
     def get_W(self) -> torch.Tensor:
-        if self.hparams.use_mask:
-            return self.generator.M
-        else:
-            W_0 = []
-            for i in range(self.x_dim):
-                weights = self.generator.fc_i[i].weight[
-                    :, :-1
-                ] # don't take the noise variable's weights
-                W_0.append(
-                    torch.sqrt(
-                        torch.sum((weights) ** 2, axis=0, keepdim=True)
-                        + self.hparams.eps
-                    )
-                )
-            return torch.cat(W_0, axis=0).T
+        return self.generator.M
 
     def dag_loss(self) -> torch.Tensor:
         W = self.get_W()
@@ -288,7 +313,7 @@ def dag_loss(self) -> torch.Tensor:
         )
 
     def sample_z(self, n: int) -> torch.Tensor:
-        return torch.rand(n, self.z_dim) * 2 - 1
+        return torch.randn(n, self.z_dim, device=DEVICE)
 
     @staticmethod
     def l1_reg(model: nn.Module) -> float:
@@ -298,9 +323,10 @@ def l1_reg(model: nn.Module) -> float:
            l1 = l1 + layer.norm(p=1)
         return l1
 
-    def gen_synthetic(
-        self, x: torch.Tensor, gen_order: Optional[list] = None, biased_edges: dict = {}
-    ) -> torch.Tensor:
+    def gen_synthetic(self, x: torch.Tensor, biased_edges: dict = {}) -> torch.Tensor:
+        self.generator = self.generator.to(DEVICE)
+        x = x.to(DEVICE)
+        gen_order = self.get_gen_order()
         return self.generator.sequential(
             x,
             self.sample_z(x.shape[0]).type_as(x),
@@ -309,15 +335,7 @@ def gen_synthetic(
         )
 
     def get_dag(self) -> np.ndarray:
-        return np.round(self.get_W().detach().numpy(), 3)
-
-    def get_bi_dag(self) -> np.ndarray:
-        dag = np.round(self.get_W().detach().numpy(), 3)
-        bi_dag = np.zeros_like(dag)
-        for i in range(len(dag)):
-            for j in range(i, len(dag)):
-                bi_dag[i][j] = dag[i][j] + dag[j][i]
-        return np.round(bi_dag, 3)
+        return np.round(self.get_W().detach().cpu().numpy(), 3)
 
     def get_gen_order(self) -> list:
         dense_dag = np.array(self.get_dag())
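Note: gen_synthetic now handles device placement and generation order itself (gen_order is derived from the learned DAG via get_gen_order), and get_dag copies the adjacency matrix back to the CPU before converting to NumPy. A usage sketch for a fitted model; the variable names are illustrative:

synth = model.gen_synthetic(raw_data, biased_edges={})   # returns a tensor on DEVICE
synth_np = synth.detach().cpu().numpy()                  # move off the GPU before NumPy work
dag = model.get_dag()                                    # rounded adjacency matrix, CPU ndarray
order = model.get_gen_order()                            # topological order used internally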
@@ -333,13 +351,8 @@ def training_step(
         # sample noise
         z = self.sample_z(batch.shape[0])
         z = z.type_as(batch)
+        generated_batch = self.generator.sequential(batch, z, self.get_gen_order())
 
-        if self.hparams.p_gen < 0:
-            generated_batch = self.generator.sequential(batch, z, self.get_gen_order())
-        else: # train simultaneously
-            raise ValueError(
-                "we're not allowing simultaneous generation no more. Set p_gen negative"
-            )
         # train generator
         if optimizer_idx == 0:
             self.iterations_d += 1
@@ -356,12 +369,10 @@
             d_loss += self.hparams.lambda_gp * self.compute_gradient_penalty(
                 batch, generated_batch
             )
+            if torch.isnan(d_loss).sum() != 0:
+                raise ValueError("NaN in the discr loss")
 
-            tqdm_dict = {"d_loss": d_loss.detach()}
-            output = OrderedDict(
-                {"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
-            )
-            return output
+            return d_loss
         elif optimizer_idx == 1:
             # sanity check: keep track of G updates
             self.iterations_g += 1
@@ -382,14 +393,10 @@ def training_step(
             if len(self.dag_seed) == 0:
                 if self.hparams.grad_dag_loss:
                     g_loss += self.gradient_dag_loss(batch, z)
+            if torch.isnan(g_loss).sum() != 0:
+                raise ValueError("NaN in the gen loss")
 
-            tqdm_dict = {"g_loss": g_loss.detach()}
-
-            output = OrderedDict(
-                {"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
-            )
-
-            return output
+            return g_loss
         else:
             raise ValueError("should not get here")
 
@@ -411,7 +418,4 @@ def configure_optimizers(self) -> tuple:
             betas=(b1, b2),
             weight_decay=weight_decay,
         )
-        return (
-            {"optimizer": opt_d, "frequency": self.hparams.d_updates},
-            {"optimizer": opt_g, "frequency": 1},
-        )
+        return [opt_d, opt_g], []
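Note: with plain lists returned from configure_optimizers, Lightning steps the discriminator (optimizer_idx 0) and the generator (optimizer_idx 1) once each per batch, and training_step returns bare loss tensors instead of OrderedDicts. If a k-to-1 discriminator schedule is ever wanted again, Lightning's frequency dicts (the form this commit removes) still express it; the sketch below uses placeholder optimizer settings and is not part of the commit:

def configure_optimizers(self) -> tuple:
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-3)
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-3)
    return (
        {"optimizer": opt_d, "frequency": 5},   # five discriminator steps...
        {"optimizer": opt_g, "frequency": 1},   # ...per generator step
    )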

decaf/data.py

Lines changed: 9 additions & 3 deletions
@@ -8,11 +8,13 @@
 
 import decaf.logger as log
 
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 class Dataset(torch.utils.data.Dataset):
     def __init__(self, data: list) -> None:
         data = np.array(data, dtype="float32")
-        self.x = torch.from_numpy(data)
+        self.x = torch.from_numpy(data).to(DEVICE)
         self.n_samples = self.x.shape[0]
         log.info("***** DATA ****")
         log.info(f"n_samples = {self.n_samples}")
@@ -49,10 +51,14 @@ def train_dataloader(self) -> DataLoader:
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(
-            self.data_val, batch_size=self.batch_size, num_workers=self.num_workers
+            self.data_val,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
         )
 
     def test_dataloader(self) -> DataLoader:
         return DataLoader(
-            self.data_test, batch_size=self.batch_size, num_workers=self.num_workers
+            self.data_test,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
        )
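Note: the Dataset now pins its tensor to DEVICE at construction, so batches already live on the GPU when one is available. A quick check with a toy array (shape and values made up):

import numpy as np
from decaf.data import Dataset

ds = Dataset(np.random.rand(16, 4).tolist())
print(ds.n_samples, ds.x.device)   # 16 cuda:0 when a GPU is visible, otherwise 16 cpu

With CUDA tensors created up front, keeping num_workers at 0 is generally the safe default, since forked DataLoader workers cannot reuse the parent process's CUDA context.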

decaf/logger.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 
 LOG_FORMAT = "[{time}][{process.id}][{level}] {message}"
 
-logger.remove()
 DEFAULT_SINK = "decaf_{time}.log"
 
 
decaf/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.1.2"
+__version__ = "0.1.3"

tests/test_decaf.py

Lines changed: 1 addition & 7 deletions
@@ -72,11 +72,5 @@ def test_sanity_generate() -> None:
 
     trainer.fit(model, dummy_dm)
 
-    synth_data = (
-        model.gen_synthetic(
-            raw_data, gen_order=model.get_gen_order(), biased_edges=bias_dict
-        )
-        .detach()
-        .numpy()
-    )
+    synth_data = model.gen_synthetic(raw_data, biased_edges=bias_dict)
     assert synth_data.shape[0] == 10
