Skip to content

Commit 6bcce2c

Browse files
IanShi1996 and xgao0o
authored
Causal StrAF refactor (#8)
* Initial commit for a causal inference section refactor. Added tests, refactored CausalARFlow into separate model and training classes. * fix causal exps * Changed output directory generation to pathlib. * Added early stopping to CausalAF code. * Fixed minor typing issues. --------- Co-authored-by: Xiang Gao <xgao@cs.toronto.edu>
1 parent 3a12f23 commit 6bcce2c

File tree

11 files changed

+836
-410
lines changed

11 files changed

+836
-410
lines changed

data/causal_sem.py

Lines changed: 158 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,40 @@
11
import numpy as np
2+
from numpy.random import normal, uniform
3+
24
import torch
3-
from numpy.random import normal, uniform, laplace
45
from torch.utils.data import Dataset
56

7+
from typing import Callable
8+
69

710
class LinAddSEM:
8-
"""
9-
Defines a linear additive SEM and sampling operations.
10-
"""
11-
def __init__(self, noise_mean, noise_stds, adj_mat, noise_dist=normal):
12-
"""Initialize SEM.
11+
"""Defines a linear additive SEM and sampling operations."""
1312

14-
All input variables should be np.array.
13+
def __init__(
14+
self,
15+
noise_mean: np.ndarray,
16+
noise_stds: np.ndarray,
17+
adj_mat: np.ndarray,
18+
noise_dist: Callable = normal
19+
):
20+
"""Initialize Linear Additive SEM.
1521
1622
Assumes autoregressive causal ordering, meaning adjacency matrix must
1723
be lower triangular.
24+
25+
Args:
26+
noise_mean: Means of the noise distributions. Has shape (D,).
27+
noise_stds: Standard dev. of noise distributions. Has shape (D,).
28+
adj_mat: Adjacency matrix of SEM. Has shape (D, D).
29+
noise_dist: Noise generating distribution
1830
"""
1931
# Check that SEM specification is valid
20-
assert(len(noise_mean) == len(noise_stds))
32+
assert len(noise_mean) == len(noise_stds)
2133

22-
assert(len(adj_mat.shape) == 2)
23-
assert(adj_mat.shape[0] == adj_mat.shape[1])
24-
assert(len(noise_mean) == adj_mat.shape[0])
25-
assert(np.allclose(adj_mat, np.tril(adj_mat)))
34+
assert len(adj_mat.shape) == 2
35+
assert adj_mat.shape[0] == adj_mat.shape[1]
36+
assert len(noise_mean) == adj_mat.shape[0]
37+
assert np.allclose(adj_mat, np.tril(adj_mat))
2638

2739
self.n_var = len(noise_mean)
2840

@@ -31,8 +43,12 @@ def __init__(self, noise_mean, noise_stds, adj_mat, noise_dist=normal):
3143
self.noise_dist = noise_dist
3244
self.adj_mat = adj_mat
3345

34-
def generate_sample(self):
35-
"""Generates a sample from specified SEM."""
46+
def generate_sample(self) -> np.ndarray:
47+
"""Generate a sample from specified SEM.
48+
49+
Returns:
50+
Single sample generated from SEM
51+
"""
3652
e = self.noise_dist(self.noise_mean, self.noise_stds)
3753

3854
out_mat = np.zeros_like(e)
@@ -42,10 +58,25 @@ def generate_sample(self):
4258

4359
return out_mat
4460

45-
def generate_samples(self, n_samp):
61+
def generate_samples(self, n_samp: int) -> np.ndarray:
62+
"""Generate multiple samples from SEM.
63+
64+
Returns:
65+
Samples from SEM in shape (n_samp, n_dim)
66+
"""
4667
return np.array([self.generate_sample() for _ in range(n_samp)])
4768

48-
def generate_intervention(self, int_val):
69+
def generate_intervention(self, int_val: list[float | None]) -> np.ndarray:
70+
"""Generate ground truth intervention on variables.
71+
72+
Args:
73+
int_val:
74+
Interventional values. Has shape (D,), but each position can
75+
be None to indicate no intervention in specific variable.
76+
77+
Returns:
78+
Sample generated under intervention
79+
"""
4980
e = self.noise_dist(self.noise_mean, self.noise_stds)
5081

5182
out_mat = np.zeros_like(e)
@@ -58,7 +89,26 @@ def generate_intervention(self, int_val):
5889

5990
return out_mat
6091

61-
def generate_int_dist(self, int_val, n_samp, return_mean=True):
92+
def generate_int_dist(
93+
self,
94+
int_val: list[float | None],
95+
n_samp: int,
96+
return_mean: bool = True
97+
) -> np.ndarray:
98+
"""Estimate ground truth interventional distribution.
99+
100+
Generates samples under intervention, and optionally returns mean.
101+
102+
Args:
103+
int_val:
104+
Interventional values. Has shape (D,), but each position can
105+
be None to indicate no intervention in specific variable.
106+
n_samp: Number of samples used to estimate distribution
107+
return_mean: Whether mean of samples should be taken
108+
109+
Returns:
110+
Interventional samples, or mean of interventional samples
111+
"""
62112
samples = []
63113
for _ in range(n_samp):
64114
samples.append(self.generate_intervention(int_val))
@@ -68,8 +118,15 @@ def generate_int_dist(self, int_val, n_samp, return_mean=True):
68118
else:
69119
return np.array(samples)
70120

71-
def generate_ctf_obs(self):
72-
"""Generates an obs from specified SEM, return both obs and noise."""
121+
def generate_ctf_obs(self) -> tuple[np.ndarray, np.ndarray]:
122+
"""Generate an sample from specified SEM, return both obs and noise.
123+
124+
Return of noise that generated sample can be used to generate ground
125+
truth values for counterfactual queries.
126+
127+
Returns:
128+
Sample from SEM, and noise that generated sample
129+
"""
73130
e = self.noise_dist(self.noise_mean, self.noise_stds)
74131

75132
out_mat = np.zeros_like(e)
@@ -79,8 +136,22 @@ def generate_ctf_obs(self):
79136

80137
return out_mat, e
81138

82-
def generate_counterfactual(self, e, ctf_val):
139+
def generate_counterfactual(
140+
self,
141+
e: np.ndarray,
142+
ctf_val: list[float | None]
143+
) -> np.ndarray:
144+
"""Generate ground truth counterfactual outcome.
83145
146+
Args:
147+
e: Noise used to generate original sample of interest
148+
ctf_val:
149+
Counterfactual values. Has shape (D,), but each position can
150+
be None to indicate no counterfactual in specific variable.
151+
152+
Returns:
153+
Counterfactual outcome of sample
154+
"""
84155
out_mat = np.zeros_like(e)
85156

86157
for i, row in enumerate(self.adj_mat):
@@ -91,29 +162,46 @@ def generate_counterfactual(self, e, ctf_val):
91162

92163
return out_mat
93164

94-
def get_carefl_ds(self, n_samp):
165+
def get_carefl_ds(
166+
self,
167+
n_samp: int
168+
) -> tuple[np.ndarray, None, np.ndarray]:
169+
"""Generate dataset using SEM.
170+
171+
Args:
172+
n_samp: Number of samples in dataset
173+
174+
Returns:
175+
Samples, unused, and adjacency matrix
176+
"""
95177
X = self.generate_samples(n_samp)
96178

97-
# Generate binary adjacency matrix
98-
cfl_adj_mat = (self.adj_mat != 0).astype(int)
99-
np.fill_diagonal(cfl_adj_mat, 0)
179+
return X, None, self.get_adj_mat()
100180

101-
return X, None, cfl_adj_mat
181+
def get_adj_mat(self) -> np.ndarray:
182+
"""Return adjacency matrix associated with SEM."""
183+
bin_adj_mat = (self.adj_mat != 0).astype(int)
184+
np.fill_diagonal(bin_adj_mat, 0)
185+
return bin_adj_mat
102186

103187

104-
class RandomSEM:
105-
"""Initializes a random LinAddSEM."""
188+
class RandomSEM(LinAddSEM):
189+
"""Initializes a LinAddSEM with randomly sampled coefficients."""
106190

107-
def __init__(self, dimension, noise_mean_param=(-2, 2),
108-
noise_std_param=(1, 10), adj_gen_param=(-2, 2)):
109-
"""Initialize SEM.
191+
def __init__(
192+
self,
193+
dimension: int,
194+
noise_mean_param: tuple[float, float] = (-2, 2),
195+
noise_std_param: tuple[float, float] = (1, 10),
196+
adj_gen_param: tuple[float, float] = (-2, 2)
197+
):
198+
"""Initialize SEM with random coefficients.
110199
111200
Args:
112-
dimension (int): Size of the graph.
113-
noise_mean_param (float, float): Parameters to generate noise mean.
114-
noise_std_param (float, float): Parameters to generate noise std.
115-
adj_gen_param (float, float): Parameters to generate adjacency
116-
weight matrix.
201+
dimension: Size of the graph
202+
noise_mean_param: Parameters to generate noise mean
203+
noise_std_param: Parameters to generate noise std
204+
adj_gen_param: Parameters to generate adjacency weight matrix
117205
"""
118206
self.n_var = dimension
119207

@@ -124,33 +212,33 @@ def __init__(self, dimension, noise_mean_param=(-2, 2),
124212
adj_mat = uniform(*adj_gen_param, size=(dimension, dimension))
125213
self.adj_mat = np.tril(adj_mat)
126214

127-
self.sem = LinAddSEM(self.noise_means, self.noise_stds, self.adj_mat)
128-
129-
def generate_samples(self, n_samples):
130-
return self.sem.generate_samples(n_samples)
215+
super().__init__(self.noise_means, self.noise_stds, self.adj_mat)
131216

132-
def generate_int_dist(self, int_val, n_samples):
133-
return self.sem.generate_int_dist(int_val, n_samples)
134217

135-
def get_adj_mat(self):
136-
bin_adj_mat = (self.adj_mat != 0).astype(int)
137-
np.fill_diagonal(bin_adj_mat, 0)
138-
return bin_adj_mat
139-
140-
141-
class SparseSEM:
218+
class SparseSEM(LinAddSEM):
142219
"""Initializes a LinAddSEM with many independencies."""
143220

144-
def __init__(self, dimension, noise_mean_param=(-1, 1),
145-
noise_std_param=(1, 1), adj_gen_param=(-2, 2)):
146-
"""Initialize SEM.
221+
def __init__(
222+
self,
223+
dimension: int,
224+
noise_mean_param: tuple[int, int] = (-1, 1),
225+
noise_std_param: tuple[int, int] = (1, 1),
226+
adj_gen_param: tuple[int, int] = (-2, 2)
227+
):
228+
"""Initialize a SEM with a highly sparse adjacency.
147229
148230
Args:
149-
dimension (int): Size of the graph.
150-
noise_mean_param (float, float): Parameters to generate noise mean.
151-
noise_std_param (float, float): Parameters to generate noise std.
152-
adj_gen_param (float, float): Parameters to generate adjacency
153-
weight matrix.
231+
dimension: Size of the graph
232+
noise_mean_param:
233+
Range of uniform distribution used to generate noise
234+
distribution means.
235+
noise_std_param:
236+
Range of uniform distribution used to generate noise
237+
distribution standard deviations.
238+
adj_gen_param:
239+
Range of uniform distribution used to generate DAG edge
240+
coefficients. Note that values less than 1.5 in absolute
241+
value are rounded to zero.
154242
"""
155243
self.n_var = dimension
156244

@@ -167,44 +255,34 @@ def __init__(self, dimension, noise_mean_param=(-1, 1),
167255

168256
self.adj_mat = adj_mat
169257

170-
self.sem = LinAddSEM(self.noise_means, self.noise_stds, self.adj_mat)
171-
172-
def generate_samples(self, n_samples):
173-
return self.sem.generate_samples(n_samples)
174-
175-
def generate_int_dist(self, int_val, n_samples, return_mean=True):
176-
return self.sem.generate_int_dist(int_val, n_samples, return_mean)
177-
178-
def generate_ctf_obs(self):
179-
return self.sem.generate_ctf_obs()
258+
super().__init__(self.noise_means, self.noise_stds, self.adj_mat)
180259

181-
def generate_counterfactual(self, e, ctf_val):
182-
return self.sem.generate_counterfactual(e, ctf_val)
183-
184-
def get_adj_mat(self):
185-
bin_adj_mat = (self.adj_mat != 0).astype(int)
186-
np.fill_diagonal(bin_adj_mat, 0)
187-
return bin_adj_mat
188-
189260

190261
class CustomSyntheticDatasetDensity(Dataset):
191-
def __init__(self, X, device='cpu'):
262+
"""PyTorch Dataset wrapper for Causal SEMs."""
263+
264+
def __init__(self, X: np.ndarray, device: str = 'cpu'):
265+
"""Initialize torch dataset used to wrap causal SEMs."""
192266
self.device = device
193267
self.x = torch.from_numpy(X).to(device)
194268
self.len = self.x.shape[0]
195269
self.data_dim = self.x.shape[1]
196270

197-
def get_dims(self):
271+
def get_dims(self) -> int:
272+
"""Get feature dimensionality of data."""
198273
return self.data_dim
199274

200-
def __len__(self):
275+
def __len__(self) -> int:
276+
"""Return length of dataset."""
201277
return self.len
202278

203-
def __getitem__(self, index):
279+
def __getitem__(self, index: int) -> torch.Tensor:
280+
"""Return single datum from dataset."""
204281
return self.x[index]
205282

206-
def get_metadata(self):
283+
def get_metadata(self) -> dict:
284+
"""Return dataset statistics."""
207285
return {
208286
'n': self.len,
209287
'data_dim': self.data_dim,
210-
}
288+
}

data/data_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def split_dataset(
1212
data: np.ndarray,
1313
split_ratio: tuple[float, float, float]
1414
) -> DSTuple:
15-
"""Splits data into train, validation, and test splits.
15+
"""Split data into train, validation, and test splits.
1616
1717
Args:
1818
data: Dataset to split. Assumes first dimension is sample dimension.
@@ -29,7 +29,7 @@ def split_dataset(
2929

3030

3131
def standardize_data(data: np.ndarray):
32-
"""Standardizes data.
32+
"""Standardize data.
3333
3434
Args:
3535
data: Data of dimension (n_samples, n_features).

data/make_adj_mtx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def generate_adj_mat_uniform(data_dim: int, threshold: float) -> np.ndarray:
5-
"""Generates adjacency matrix with uniform sparsity.
5+
"""Generate adjacency matrix with uniform sparsity.
66
77
Args:
88
data_dim: Dimension of data.

experiments/synthetic_causality/config/baseline.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@ carefl:
99
flow:
1010
nl: 5
1111
nh: 10
12-
batch_norm: false
1312
prior_dist: 'laplace'
14-
# for CL
15-
# scale_base: true
16-
# shift_base: true
17-
# scale: true
1813

1914

2015
training:
@@ -23,6 +18,7 @@ training:
2318
split: .8
2419
seed: 0
2520
batch_size: 32
21+
early_stop_patience: 25
2622

2723

2824
optim:

0 commit comments

Comments
 (0)