from collections import OrderedDict
- from typing import Any, Optional, Union
+ from typing import Any, List, Optional, Union

import networkx as nx
import numpy as np
import pytorch_lightning as pl
- import scipy.linalg as slin
import torch
import torch.nn as nn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


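+ # Helper mapping an activation name to the corresponding torch.nn module.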
+ def get_nonlin(name: str) -> nn.Module:
+     if name == "none":
+         return nn.Identity()
+     elif name == "elu":
+         return nn.ELU()
+     elif name == "relu":
+         return nn.ReLU()
+     elif name == "leaky_relu":
+         return nn.LeakyReLU()
+     elif name == "selu":
+         return nn.SELU()
+     elif name == "tanh":
+         return nn.Tanh()
+     elif name == "sigmoid":
+         return nn.Sigmoid()
+     elif name == "softmax":
+         return nn.Softmax(dim=-1)
+     else:
+         raise ValueError(f"Unknown nonlinearity {name}")
+
+
class TraceExpm(torch.autograd.Function):
    @staticmethod
-     def forward(ctx: Any, input: torch.Tensor) -> torch.Tensor:
-         # detach so we can cast to NumPy
-         E = slin.expm(input.detach().numpy())
-         f = np.trace(E)
-         E = torch.from_numpy(E).to(DEVICE)
+     def forward(ctx: Any, data: torch.Tensor) -> torch.Tensor:
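+         # Compute the matrix exponential natively in torch so the tensor never leaves the device.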
+         E = torch.linalg.matrix_exp(data)
+         f = torch.trace(E)
        ctx.save_for_backward(E)
-         return torch.as_tensor(f, dtype=input.dtype)
+         return torch.as_tensor(f, dtype=data.dtype)

@@ -32,49 +50,57 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:

trace_expm = TraceExpm.apply

- activation_layer = nn.ReLU(inplace=True)
-

class Generator_causal(nn.Module):
    def __init__(
        self,
        z_dim: int,
        x_dim: int,
        h_dim: int,
-         use_mask: bool = False,
        f_scale: float = 0.1,
        dag_seed: list = [],
+         nonlin_out: Optional[List] = None,
    ) -> None:
        super().__init__()

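+         # nonlin_out is an optional list of (activation_name, width) pairs; the widths must sum to x_dim.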
+         if nonlin_out is not None:
+             out_dim = 0
+             for act, length in nonlin_out:
+                 out_dim += length
+             if out_dim != x_dim:
+                 raise RuntimeError("Invalid nonlin_out")
+
        self.x_dim = x_dim
+         self.nonlin_out = nonlin_out

        def block(in_feat: int, out_feat: int, normalize: bool = False) -> list:
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
-             layers.append(activation_layer)
+             layers.append(nn.ReLU(inplace=True))
            return layers

-         self.shared = nn.Sequential(*block(h_dim, h_dim), *block(h_dim, h_dim))
-
-         if use_mask:
+         self.shared = nn.Sequential(*block(h_dim, h_dim), *block(h_dim, h_dim)).to(
+             DEVICE
+         )

-             if len(dag_seed) > 0:
-                 M_init = torch.rand(x_dim, x_dim) * 0.0
-                 M_init[torch.eye(x_dim, dtype=bool)] = 0
-                 M_init = torch.rand(x_dim, x_dim) * 0.0
-                 for pair in dag_seed:
-                     M_init[pair[0], pair[1]] = 1
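+         # With a dag_seed, M becomes a fixed binary adjacency mask (requires_grad=False); otherwise M is a small random matrix that is learned.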
+         if len(dag_seed) > 0:
+             M_init = torch.rand(x_dim, x_dim) * 0.0
+             M_init[torch.eye(x_dim, dtype=bool)] = 0
+             M_init = torch.rand(x_dim, x_dim) * 0.0
+             for pair in dag_seed:
+                 M_init[pair[0], pair[1]] = 1

-                 self.M = torch.nn.parameter.Parameter(M_init, requires_grad=False)
-                 print("Initialised adjacency matrix as parsed:\n", self.M)
-             else:
-                 M_init = torch.rand(x_dim, x_dim) * 0.2
-                 M_init[torch.eye(x_dim, dtype=bool)] = 0
-                 self.M = torch.nn.parameter.Parameter(M_init)
+             M_init = M_init.to(DEVICE)
+             self.M = torch.nn.parameter.Parameter(M_init, requires_grad=False).to(
+                 DEVICE
+             )
        else:
-             self.M = torch.ones(x_dim, x_dim)
+             M_init = torch.rand(x_dim, x_dim) * 0.2
+             M_init[torch.eye(x_dim, dtype=bool)] = 0
+             M_init = M_init.to(DEVICE)
+             self.M = torch.nn.parameter.Parameter(M_init).to(DEVICE)
+
        self.fc_i = nn.ModuleList(
            [nn.Linear(x_dim + 1, h_dim) for i in range(self.x_dim)]
        )
@@ -111,13 +137,28 @@ def sequential(
            x_masked[:, i] = 0.0
            if i in biased_edges:
                for j in biased_edges[i]:
-                     x_j = x_masked[:, j].detach().numpy()
-                     np.random.shuffle(x_j)
-                     x_masked[:, j] = torch.from_numpy(x_j)
-             out_i = activation_layer(
-                 self.fc_i[i](torch.cat([x_masked, z[:, i].unsqueeze(1)], axis=1))
-             )
-             out[:, i] = nn.Sigmoid()(self.fc_f[i](self.shared(out_i))).squeeze()
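+                     # Shuffle the parent column with a torch-native permutation instead of a NumPy round trip.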
+                     x_j = x_masked[:, j]
+                     perm = torch.randperm(len(x_j))
+                     x_masked[:, j] = x_j[perm]
+             out_i = self.fc_i[i](torch.cat([x_masked, z[:, i].unsqueeze(1)], axis=1))
+             out_i = nn.ReLU()(out_i)
+             out_i = self.shared(out_i)
+             out_i = self.fc_f[i](out_i).squeeze()
+             out[:, i] = out_i
+
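+         # Optionally apply per-block output activations (e.g. softmax over one-hot groups) as declared in nonlin_out.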
+         if self.nonlin_out is not None:
+             split = 0
+             for act_name, step in self.nonlin_out:
+                 activation = get_nonlin(act_name)
+                 out[..., split : split + step] = activation(
+                     out[..., split : split + step]
+                 )
+
+                 split += step
+
+             if split != out.shape[-1]:
+                 raise ValueError("Invalid activations")
+
        return out


@@ -127,9 +168,9 @@ def __init__(self, x_dim: int, h_dim: int) -> None:

        self.model = nn.Sequential(
            nn.Linear(x_dim, h_dim),
-             activation_layer,
+             nn.ReLU(),
            nn.Linear(h_dim, h_dim),
-             activation_layer,
+             nn.ReLU(),
            nn.Linear(h_dim, 1),
        )

@@ -153,16 +194,14 @@ def __init__(
        batch_size: int = 32,
        lambda_gp: float = 10,
        lambda_privacy: float = 1,
-         d_updates: int = 5,
        eps: float = 1e-8,
        alpha: float = 1,
        rho: float = 1,
        weight_decay: float = 1e-2,
        grad_dag_loss: bool = False,
        l1_g: float = 0,
        l1_W: float = 1,
-         p_gen: float = -1,
-         use_mask: bool = False,
+         nonlin_out: Optional[List] = None,
    ):
        super().__init__()
        self.save_hyperparameters()
@@ -183,8 +222,8 @@ def __init__(
            z_dim=self.z_dim,
            x_dim=self.x_dim,
            h_dim=h_dim,
-             use_mask=use_mask,
            dag_seed=dag_seed,
+             nonlin_out=nonlin_out,
        ).to(DEVICE)
        self.discriminator = Discriminator(x_dim=self.x_dim, h_dim=h_dim).to(DEVICE)

@@ -261,21 +300,7 @@ def privacy_loss(
        )

    def get_W(self) -> torch.Tensor:
-         if self.hparams.use_mask:
-             return self.generator.M
-         else:
-             W_0 = []
-             for i in range(self.x_dim):
-                 weights = self.generator.fc_i[i].weight[
-                     :, :-1
-                 ]  # don't take the noise variable's weights
-                 W_0.append(
-                     torch.sqrt(
-                         torch.sum((weights) ** 2, axis=0, keepdim=True)
-                         + self.hparams.eps
-                     )
-                 )
-             return torch.cat(W_0, axis=0).T
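+         # The adjacency matrix M is now the single source of truth for W; the implicit first-layer-weights path is gone.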
+         return self.generator.M

    def dag_loss(self) -> torch.Tensor:
        W = self.get_W()
@@ -288,7 +313,7 @@ def dag_loss(self) -> torch.Tensor:
        )

    def sample_z(self, n: int) -> torch.Tensor:
-         return torch.rand(n, self.z_dim) * 2 - 1
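+         # Latent noise is now standard normal and created directly on DEVICE (previously uniform on [-1, 1]).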
+         return torch.randn(n, self.z_dim, device=DEVICE)

    @staticmethod
    def l1_reg(model: nn.Module) -> float:
@@ -298,9 +323,10 @@ def l1_reg(model: nn.Module) -> float:
            l1 = l1 + layer.norm(p=1)
        return l1

-     def gen_synthetic(
-         self, x: torch.Tensor, gen_order: Optional[list] = None, biased_edges: dict = {}
-     ) -> torch.Tensor:
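+     # gen_order is no longer passed by the caller; it is derived from the learned DAG via get_gen_order().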
+     def gen_synthetic(self, x: torch.Tensor, biased_edges: dict = {}) -> torch.Tensor:
+         self.generator = self.generator.to(DEVICE)
+         x = x.to(DEVICE)
+         gen_order = self.get_gen_order()
        return self.generator.sequential(
            x,
            self.sample_z(x.shape[0]).type_as(x),
@@ -309,15 +335,7 @@ def gen_synthetic(
        )

    def get_dag(self) -> np.ndarray:
-         return np.round(self.get_W().detach().numpy(), 3)
-
-     def get_bi_dag(self) -> np.ndarray:
-         dag = np.round(self.get_W().detach().numpy(), 3)
-         bi_dag = np.zeros_like(dag)
-         for i in range(len(dag)):
-             for j in range(i, len(dag)):
-                 bi_dag[i][j] = dag[i][j] + dag[j][i]
-         return np.round(bi_dag, 3)
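+         # Detach and move to CPU before converting to NumPy, since M may now live on the GPU.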
+         return np.round(self.get_W().detach().cpu().numpy(), 3)

    def get_gen_order(self) -> list:
        dense_dag = np.array(self.get_dag())
@@ -333,13 +351,8 @@ def training_step(
        # sample noise
        z = self.sample_z(batch.shape[0])
        z = z.type_as(batch)
+         generated_batch = self.generator.sequential(batch, z, self.get_gen_order())

-         if self.hparams.p_gen < 0:
-             generated_batch = self.generator.sequential(batch, z, self.get_gen_order())
-         else:  # train simultaneously
-             raise ValueError(
-                 "we're not allowing simultaneous generation no more. Set p_gen negative"
-             )
        # train discriminator
        if optimizer_idx == 0:
            self.iterations_d += 1
@@ -356,12 +369,10 @@ def training_step(
            d_loss += self.hparams.lambda_gp * self.compute_gradient_penalty(
                batch, generated_batch
            )
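+             # Fail fast if the discriminator loss diverged to NaN.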
+             if torch.isnan(d_loss).sum() != 0:
+                 raise ValueError("NaN in the discr loss")

-             tqdm_dict = {"d_loss": d_loss.detach()}
-             output = OrderedDict(
-                 {"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
-             )
-             return output
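+             # Return the loss tensor directly instead of the old OrderedDict with progress_bar/log entries.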
+             return d_loss
        elif optimizer_idx == 1:
            # sanity check: keep track of G updates
            self.iterations_g += 1
@@ -382,14 +393,10 @@ def training_step(
            if len(self.dag_seed) == 0:
                if self.hparams.grad_dag_loss:
                    g_loss += self.gradient_dag_loss(batch, z)
+             if torch.isnan(g_loss).sum() != 0:
+                 raise ValueError("NaN in the gen loss")

-             tqdm_dict = {"g_loss": g_loss.detach()}
-
-             output = OrderedDict(
-                 {"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
-             )
-
-             return output
+             return g_loss
        else:
            raise ValueError("should not get here")

@@ -411,7 +418,4 @@ def configure_optimizers(self) -> tuple:
            betas=(b1, b2),
            weight_decay=weight_decay,
        )
-         return (
-             {"optimizer": opt_d, "frequency": self.hparams.d_updates},
-             {"optimizer": opt_g, "frequency": 1},
-         )
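+         # Return plain optimizer/scheduler lists; the per-optimizer "frequency" (d_updates) scheduling is dropped.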
+         return [opt_d, opt_g], []