Skip to content

Commit 9baf4c8

Browse files
committed
Minor adjustments before merge
1 parent 356d8d0 commit 9baf4c8

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

tests/test_kde.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,19 @@
2626

2727
N1 = 100
2828
N2 = 10
29+
N3 = 1000
30+
2931
GRID_N = 1000
3032
GRID_RG = 100
3133

34+
# parameters to test whether sampling adheres to weights
35+
COMPONENT_WEIGHTS = [0.9, 0.05]
36+
LOC1 = [10.0, 10.0]
37+
LOC2 = [0.0, 0.0]
38+
COV1 = [1.0, 1.0]
39+
COV2 = [1.0, 1.0]
40+
THRESHOLD = 5.0
41+
3242

3343
class TestKDE(unittest.TestCase):
3444
def setUp(self):
@@ -147,16 +157,15 @@ def fit_and_eval(X, X_new, bandwidth):
147157

148158
def test_sampling_adheres_to_weights(self,):
149159
# test if the sample_weights passed in fit(X, sample_weights=...) are respected on sampling
150-
n_samples=2000
151-
152160
# create a GMM with 2 components with weights pi
153-
pi = torch.tensor([0.9, 0.05]) # 90 % for component 1, 0.5 % for component 2 (does not need to sum to 1)
154-
loc1 = torch.tensor([10.,10])
155-
cov1 = torch.diag(torch.tensor([1.,1]))
156-
loc2 = torch.tensor([0.,0])
157-
cov2 = torch.diag(torch.tensor([1.,1]))
158-
weights1 = torch.ones((n_samples,))*pi[0]
159-
weights2 = torch.ones((n_samples,))*pi[1]
161+
pi = torch.tensor(COMPONENT_WEIGHTS) # 90 % for component 1, 0.5 % for component 2 (does not need to sum to 1)
162+
loc1 = torch.tensor(LOC1)
163+
cov1 = torch.diag(torch.tensor(COV1))
164+
loc2 = torch.tensor(LOC2)
165+
cov2 = torch.diag(torch.tensor(COV2))
166+
167+
weights1 = torch.ones((N3,))*pi[0]
168+
weights2 = torch.ones((N3,))*pi[1]
160169

161170
locs = torch.stack([loc1, loc2]) # Shape: [n_components, event_shape] = [2, 2]
162171
covs = torch.stack([cov1, cov2]) # Shape: [n_components, event_shape, event_shape] = [2, 2, 2]
@@ -166,32 +175,32 @@ def test_sampling_adheres_to_weights(self,):
166175
covariance_matrix=covs
167176
)
168177

169-
# 2. Create the mixing distribution (Categorical)
178+
# Create the mixing distribution (Categorical)
170179
# pi is interpreted as weights in linear-scale
171180
mixing_distribution = dist.Categorical(probs=pi)
172181

173-
# 3. Create the MixtureSameFamily distribution
182+
# Create the MixtureSameFamily distribution
174183
# This combines the mixing and component distributions
175184
gmm = dist.MixtureSameFamily(
176185
mixture_distribution=mixing_distribution,
177186
component_distribution=component_distribution
178187
)
179188

180189

181-
X = component_distribution.sample((n_samples,))
190+
X = component_distribution.sample((N3,))
182191
X1 = X[:,0,:]
183192
X2 = X[:,1,:]
184193

185194
kde = torchkde.KernelDensity(bandwidth=.5, kernel='gaussian') # create kde object with isotropic bandwidth matrix
186195
kde.fit(torch.concat((X1, X2), dim=0), sample_weight=torch.concat((weights1, weights2), dim=0)) # fit kde to weighted data
187196

188-
samples_from_kde = kde.sample(n_samples)
189-
samples_from_gmm = gmm.sample((n_samples,))
197+
samples_from_kde = kde.sample(N1)
198+
samples_from_gmm = gmm.sample((N1,))
199+
200+
component_1_fraction_kde = torch.count_nonzero(torch.where(samples_from_kde[:,0] > THRESHOLD, 1.0, 0.0)) / N3
201+
component_1_fraction_gmm = torch.count_nonzero(torch.where(samples_from_gmm[:,0] > THRESHOLD, 1.0, 0.0)) / N3
190202

191-
component_1_fraction_kde = torch.count_nonzero(torch.where(samples_from_kde[:,0] > 5., 1.0, 0.0)) / n_samples
192-
component_1_fraction_gmm = torch.count_nonzero(torch.where(samples_from_gmm[:,0] > 5., 1.0, 0.0)) / n_samples
193-
print(component_1_fraction_kde / component_1_fraction_gmm)
194-
self.assertTrue(0.9 < (component_1_fraction_kde / component_1_fraction_gmm) < 1.1, "Component weights must be considered on sampling.")
203+
self.assertTrue(1.0 - TOLERANCE < (component_1_fraction_kde / component_1_fraction_gmm) < 1.0 + TOLERANCE, "Component weights must be considered on sampling.")
195204

196205

197206
def sample_from_gaussian(dim, N):

torchkde/modules.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,4 +206,3 @@ def sample(self, n_samples: int = 1) -> torch.Tensor:
206206
X = self.bandwidth * torch.randn(n_samples, data.shape[1]) + data[idxs]
207207

208208
return ensure_two_dimensional(X)
209-

0 commit comments

Comments
 (0)