11import numpy as np
2+ from numpy .random import normal , uniform
3+
24import torch
3- from numpy .random import normal , uniform , laplace
45from torch .utils .data import Dataset
56
7+ from typing import Callable
8+
69
710class LinAddSEM :
8- """
9- Defines a linear additive SEM and sampling operations.
10- """
11- def __init__ (self , noise_mean , noise_stds , adj_mat , noise_dist = normal ):
12- """Initialize SEM.
11+ """Defines a linear additive SEM and sampling operations."""
1312
14- All input variables should be np.array.
13+ def __init__ (
14+ self ,
15+ noise_mean : np .ndarray ,
16+ noise_stds : np .ndarray ,
17+ adj_mat : np .ndarray ,
18+ noise_dist : Callable = normal
19+ ):
20+ """Initialize Linear Additive SEM.
1521
1622 Assumes autoregressive causal ordering, meaning adjacency matrix must
1723 be lower triangular.
24+
25+ Args:
26+ noise_mean: Means of the noise distributions. Has shape (D,).
27+ noise_stds: Standard dev. of noise distributions. Has shape (D,).
28+ adj_mat: Adjacency matrix of SEM. Has shape (D, D).
29+ noise_dist: Noise generating distribution
1830 """
1931 # Check that SEM specification is valid
20- assert ( len (noise_mean ) == len (noise_stds ) )
32+ assert len (noise_mean ) == len (noise_stds )
2133
22- assert ( len (adj_mat .shape ) == 2 )
23- assert ( adj_mat .shape [0 ] == adj_mat .shape [1 ])
24- assert ( len (noise_mean ) == adj_mat .shape [0 ])
25- assert ( np .allclose (adj_mat , np .tril (adj_mat ) ))
34+ assert len (adj_mat .shape ) == 2
35+ assert adj_mat .shape [0 ] == adj_mat .shape [1 ]
36+ assert len (noise_mean ) == adj_mat .shape [0 ]
37+ assert np .allclose (adj_mat , np .tril (adj_mat ))
2638
2739 self .n_var = len (noise_mean )
2840
@@ -31,8 +43,12 @@ def __init__(self, noise_mean, noise_stds, adj_mat, noise_dist=normal):
3143 self .noise_dist = noise_dist
3244 self .adj_mat = adj_mat
3345
34- def generate_sample (self ):
35- """Generates a sample from specified SEM."""
46+ def generate_sample (self ) -> np .ndarray :
47+ """Generate a sample from specified SEM.
48+
49+ Returns:
50+ Single sample generated from SEM
51+ """
3652 e = self .noise_dist (self .noise_mean , self .noise_stds )
3753
3854 out_mat = np .zeros_like (e )
@@ -42,10 +58,25 @@ def generate_sample(self):
4258
4359 return out_mat
4460
45- def generate_samples (self , n_samp ):
61+ def generate_samples (self , n_samp : int ) -> np .ndarray :
62+ """Generate multiple samples from SEM.
63+
64+ Returns:
65+ Samples from SEM in shape (n_samp, n_dim)
66+ """
4667 return np .array ([self .generate_sample () for _ in range (n_samp )])
4768
48- def generate_intervention (self , int_val ):
69+ def generate_intervention (self , int_val : list [float | None ]) -> np .ndarray :
70+ """Generate ground truth intervention on variables.
71+
72+ Args:
73+ int_val:
74+ Interventional values. Has shape (D,), but each position can
75+ be None to indicate no intervention on that specific variable.
76+
77+ Returns:
78+ Sample generated under intervention
79+ """
4980 e = self .noise_dist (self .noise_mean , self .noise_stds )
5081
5182 out_mat = np .zeros_like (e )
@@ -58,7 +89,26 @@ def generate_intervention(self, int_val):
5889
5990 return out_mat
6091
61- def generate_int_dist (self , int_val , n_samp , return_mean = True ):
92+ def generate_int_dist (
93+ self ,
94+ int_val : list [float | None ],
95+ n_samp : int ,
96+ return_mean : bool = True
97+ ) -> np .ndarray :
98+ """Estimate ground truth interventional distribution.
99+
100+ Generates samples under intervention, and optionally returns mean.
101+
102+ Args:
103+ int_val:
104+ Interventional values. Has shape (D,), but each position can
105+ be None to indicate no intervention on that specific variable.
106+ n_samp: Number of samples used to estimate distribution
107+ return_mean: Whether mean of samples should be taken
108+
109+ Returns:
110+ Interventional samples, or mean of interventional samples
111+ """
62112 samples = []
63113 for _ in range (n_samp ):
64114 samples .append (self .generate_intervention (int_val ))
@@ -68,8 +118,15 @@ def generate_int_dist(self, int_val, n_samp, return_mean=True):
68118 else :
69119 return np .array (samples )
70120
71- def generate_ctf_obs (self ):
72- """Generates an obs from specified SEM, return both obs and noise."""
121+ def generate_ctf_obs (self ) -> tuple [np .ndarray , np .ndarray ]:
122+ """Generate a sample from specified SEM, return both obs and noise.
123+
124+ Returning the noise that generated the sample allows ground truth
125+ values to be computed for counterfactual queries.
126+
127+ Returns:
128+ Sample from SEM, and noise that generated sample
129+ """
73130 e = self .noise_dist (self .noise_mean , self .noise_stds )
74131
75132 out_mat = np .zeros_like (e )
@@ -79,8 +136,22 @@ def generate_ctf_obs(self):
79136
80137 return out_mat , e
81138
82- def generate_counterfactual (self , e , ctf_val ):
139+ def generate_counterfactual (
140+ self ,
141+ e : np .ndarray ,
142+ ctf_val : list [float | None ]
143+ ) -> np .ndarray :
144+ """Generate ground truth counterfactual outcome.
83145
146+ Args:
147+ e: Noise used to generate original sample of interest
148+ ctf_val:
149+ Counterfactual values. Has shape (D,), but each position can
150+ be None to indicate no counterfactual in specific variable.
151+
152+ Returns:
153+ Counterfactual outcome of sample
154+ """
84155 out_mat = np .zeros_like (e )
85156
86157 for i , row in enumerate (self .adj_mat ):
@@ -91,29 +162,46 @@ def generate_counterfactual(self, e, ctf_val):
91162
92163 return out_mat
93164
94- def get_carefl_ds (self , n_samp ):
165+ def get_carefl_ds (
166+ self ,
167+ n_samp : int
168+ ) -> tuple [np .ndarray , None , np .ndarray ]:
169+ """Generate dataset using SEM.
170+
171+ Args:
172+ n_samp: Number of samples in dataset
173+
174+ Returns:
175+ Samples, an unused placeholder (always None), and binary adjacency matrix
176+ """
95177 X = self .generate_samples (n_samp )
96178
97- # Generate binary adjacency matrix
98- cfl_adj_mat = (self .adj_mat != 0 ).astype (int )
99- np .fill_diagonal (cfl_adj_mat , 0 )
179+ return X , None , self .get_adj_mat ()
100180
101- return X , None , cfl_adj_mat
181+ def get_adj_mat (self ) -> np .ndarray :
182+ """Return binary adjacency matrix of SEM, with diagonal set to zero."""
183+ bin_adj_mat = (self .adj_mat != 0 ).astype (int )
184+ np .fill_diagonal (bin_adj_mat , 0 )
185+ return bin_adj_mat
102186
103187
104- class RandomSEM :
105- """Initializes a random LinAddSEM."""
188+ class RandomSEM ( LinAddSEM ) :
189+ """Initializes a LinAddSEM with randomly sampled coefficients ."""
106190
107- def __init__ (self , dimension , noise_mean_param = (- 2 , 2 ),
108- noise_std_param = (1 , 10 ), adj_gen_param = (- 2 , 2 )):
109- """Initialize SEM.
191+ def __init__ (
192+ self ,
193+ dimension : int ,
194+ noise_mean_param : tuple [float , float ] = (- 2 , 2 ),
195+ noise_std_param : tuple [float , float ] = (1 , 10 ),
196+ adj_gen_param : tuple [float , float ] = (- 2 , 2 )
197+ ):
198+ """Initialize SEM with random coefficients.
110199
111200 Args:
112- dimension (int): Size of the graph.
113- noise_mean_param (float, float): Parameters to generate noise mean.
114- noise_std_param (float, float): Parameters to generate noise std.
115- adj_gen_param (float, float): Parameters to generate adjacency
116- weight matrix.
201+ dimension: Size of the graph
202+ noise_mean_param: Parameters to generate noise mean
203+ noise_std_param: Parameters to generate noise std
204+ adj_gen_param: Parameters to generate adjacency weight matrix
117205 """
118206 self .n_var = dimension
119207
@@ -124,33 +212,33 @@ def __init__(self, dimension, noise_mean_param=(-2, 2),
124212 adj_mat = uniform (* adj_gen_param , size = (dimension , dimension ))
125213 self .adj_mat = np .tril (adj_mat )
126214
127- self .sem = LinAddSEM (self .noise_means , self .noise_stds , self .adj_mat )
128-
129- def generate_samples (self , n_samples ):
130- return self .sem .generate_samples (n_samples )
215+ super ().__init__ (self .noise_means , self .noise_stds , self .adj_mat )
131216
132- def generate_int_dist (self , int_val , n_samples ):
133- return self .sem .generate_int_dist (int_val , n_samples )
134217
135- def get_adj_mat (self ):
136- bin_adj_mat = (self .adj_mat != 0 ).astype (int )
137- np .fill_diagonal (bin_adj_mat , 0 )
138- return bin_adj_mat
139-
140-
141- class SparseSEM :
218+ class SparseSEM (LinAddSEM ):
142219 """Initializes a LinAddSEM with many independencies."""
143220
144- def __init__ (self , dimension , noise_mean_param = (- 1 , 1 ),
145- noise_std_param = (1 , 1 ), adj_gen_param = (- 2 , 2 )):
146- """Initialize SEM.
221+ def __init__ (
222+ self ,
223+ dimension : int ,
224+ noise_mean_param : tuple [int , int ] = (- 1 , 1 ),
225+ noise_std_param : tuple [int , int ] = (1 , 1 ),
226+ adj_gen_param : tuple [int , int ] = (- 2 , 2 )
227+ ):
228+ """Initialize a SEM with a highly sparse adjacency.
147229
148230 Args:
149- dimension (int): Size of the graph.
150- noise_mean_param (float, float): Parameters to generate noise mean.
151- noise_std_param (float, float): Parameters to generate noise std.
152- adj_gen_param (float, float): Parameters to generate adjacency
153- weight matrix.
231+ dimension: Size of the graph
232+ noise_mean_param:
233+ Range of uniform distribution used to generate noise
234+ distribution means.
235+ noise_std_param:
236+ Range of uniform distribution used to generate noise
237+ distribution standard deviations.
238+ adj_gen_param:
239+ Range of uniform distribution used to generate DAG edge
240+ coefficients. Note that values less than 1.5 in absolute
241+ value are rounded to zero.
154242 """
155243 self .n_var = dimension
156244
@@ -167,44 +255,34 @@ def __init__(self, dimension, noise_mean_param=(-1, 1),
167255
168256 self .adj_mat = adj_mat
169257
170- self .sem = LinAddSEM (self .noise_means , self .noise_stds , self .adj_mat )
171-
172- def generate_samples (self , n_samples ):
173- return self .sem .generate_samples (n_samples )
174-
175- def generate_int_dist (self , int_val , n_samples , return_mean = True ):
176- return self .sem .generate_int_dist (int_val , n_samples , return_mean )
177-
178- def generate_ctf_obs (self ):
179- return self .sem .generate_ctf_obs ()
258+ super ().__init__ (self .noise_means , self .noise_stds , self .adj_mat )
180259
181- def generate_counterfactual (self , e , ctf_val ):
182- return self .sem .generate_counterfactual (e , ctf_val )
183-
184- def get_adj_mat (self ):
185- bin_adj_mat = (self .adj_mat != 0 ).astype (int )
186- np .fill_diagonal (bin_adj_mat , 0 )
187- return bin_adj_mat
188-
189260
190261class CustomSyntheticDatasetDensity (Dataset ):
191- def __init__ (self , X , device = 'cpu' ):
262+ """PyTorch Dataset wrapper for Causal SEMs."""
263+
264+ def __init__ (self , X : np .ndarray , device : str = 'cpu' ):
265+ """Initialize torch dataset used to wrap causal SEMs."""
192266 self .device = device
193267 self .x = torch .from_numpy (X ).to (device )
194268 self .len = self .x .shape [0 ]
195269 self .data_dim = self .x .shape [1 ]
196270
197- def get_dims (self ):
271+ def get_dims (self ) -> int :
272+ """Get feature dimensionality of data."""
198273 return self .data_dim
199274
200- def __len__ (self ):
275+ def __len__ (self ) -> int :
276+ """Return length of dataset."""
201277 return self .len
202278
203- def __getitem__ (self , index ):
279+ def __getitem__ (self , index : int ) -> torch .Tensor :
280+ """Return single datum from dataset."""
204281 return self .x [index ]
205282
206- def get_metadata (self ):
283+ def get_metadata (self ) -> dict :
284+ """Return dataset statistics."""
207285 return {
208286 'n' : self .len ,
209287 'data_dim' : self .data_dim ,
210- }
288+ }
0 commit comments