Skip to content

Commit 356d8d0

Browse files
[Sample KDE] add a test
1 parent ddfc1bd commit 356d8d0

File tree

1 file changed

+210
-160
lines changed

1 file changed

+210
-160
lines changed

tests/test_kde.py

Lines changed: 210 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -1,160 +1,210 @@
1-
"""Tests that check whether the kernel density estimator behaves as expected."""
2-
3-
from itertools import product
4-
from functools import partial
5-
import math
6-
import random
7-
import unittest
8-
9-
import torch
10-
import numpy as np
11-
from scipy.special import gamma
12-
from torch.autograd import gradcheck
13-
14-
from torchkde.kernels import *
15-
from torchkde.modules import KernelDensity
16-
from torchkde.bandwidths import SUPPORTED_BANDWIDTHS
17-
18-
BANDWIDTHS = [1.0, 5.0] + SUPPORTED_BANDWIDTHS
19-
DIMS = [1, 2]
20-
TOLERANCE = 1e-1
21-
WEIGHTS = [False, True]
22-
23-
DEVICES = ["cpu"]
24-
25-
N1 = 100
26-
N2 = 10
27-
GRID_N = 1000
28-
GRID_RG = 100
29-
30-
31-
class TestKDE(unittest.TestCase):
32-
def setUp(self):
33-
torch.manual_seed(0)
34-
random.seed(0)
35-
np.random.seed(0)
36-
torch.backends.cudnn.deterministic = True
37-
torch.backends.cudnn.benchmark = False
38-
39-
def test_integral(self):
40-
"""Test that the kernel density estimator integrates to 1."""
41-
for kernel_str, bandwidth, dim, weights in product(SUPPORTED_KERNELS, BANDWIDTHS, DIMS, WEIGHTS):
42-
if kernel_str == 'von-mises-fisher':
43-
# Skip the von-mises-fisher kernel, must be handled differently
44-
# as it is not defined in the same way as the other kernels
45-
continue
46-
X = sample_from_gaussian(dim, N1)
47-
# Fit a kernel density estimator to the data
48-
kde = KernelDensity(bandwidth=bandwidth, kernel=kernel_str)
49-
if weights:
50-
weights = torch.rand((N1,)).exp()
51-
_ = kde.fit(X, sample_weight=weights)
52-
else:
53-
_ = kde.fit(X)
54-
# assess whether the kernel integrates to 1
55-
# evaluate the kernel density estimator at a grid of 2D points
56-
# Create ranges for each dimension
57-
ranges = [torch.linspace(-GRID_RG, GRID_RG, GRID_N) for _ in range(dim)]
58-
# Create the d-dimensional meshgrid
59-
meshgrid = torch.meshgrid(*ranges, indexing='ij') # 'ij' indexing for Cartesian coordinates
60-
61-
# Convert meshgrid to a single tensor of shape (n_points, d)
62-
grid_points = torch.stack(meshgrid, dim=-1).reshape(-1, dim)
63-
probs = kde.score_samples(grid_points).exp()
64-
delta = (GRID_RG * 2) / GRID_N
65-
integral = probs.sum() * (delta**dim)
66-
self.assertTrue((integral - 1.0).abs() < TOLERANCE,
67-
f"""Kernel {kernel_str}, for dimensionality {str(dim)}
68-
and bandwidth {str(bandwidth)} does not integrate to 1.""")
69-
70-
def test_vmf_integral(self): # von-Mises-Fisher must be tested differently from other kernels
71-
for bandwidth, dim, weights in product(BANDWIDTHS, DIMS, WEIGHTS):
72-
if dim == 1 or type(bandwidth) == str:
73-
# Skip the von-mises-fisher kernel, must be handled differently
74-
# as it is not defined in the same way as the other kernels
75-
continue
76-
X = sample_from_gaussian(dim, N1)
77-
X = X / X.norm(dim=1, keepdim=True) # project onto sphere
78-
# Fit a kernel density estimator to the data
79-
kde = KernelDensity(bandwidth=bandwidth, kernel='von-mises-fisher')
80-
if weights:
81-
weights = torch.rand((N1,)).exp()
82-
_ = kde.fit(X, sample_weight=weights)
83-
else:
84-
_ = kde.fit(X)
85-
# assess whether the kernel integrates to 1
86-
87-
# Create the d-dimensional meshgrid
88-
mesh_samples = sample_from_gaussian(dim, GRID_N**dim)
89-
mesh_samples = mesh_samples / mesh_samples.norm(dim=1, keepdim=True) # project onto sphere
90-
91-
probs = kde.score_samples(mesh_samples).exp()
92-
surface_area = 2 * math.pi ** (dim / 2) / gamma(dim / 2)
93-
integral = probs.mean() * surface_area
94-
self.assertTrue((integral - 1.0).abs() < TOLERANCE,
95-
f"""Kernel von-mises-fisher, for dimensionality {str(dim)}
96-
and bandwidth {str(bandwidth)} does not integrate to 1.""")
97-
98-
99-
def test_diffble(self, bandwidth=torch.tensor(1.0), eps=1e-07):
100-
"""Test that the kernel density estimator is differentiable."""
101-
for kernel_str, dim in product(SUPPORTED_KERNELS, DIMS):
102-
def fit_and_eval(X, X_new, bandwidth):
103-
kde = KernelDensity(bandwidth=bandwidth, kernel=kernel_str)
104-
_ = kde.fit(X)
105-
return kde.score_samples(X_new)
106-
X = sample_from_gaussian(dim, N1).to(torch.float64) # relevant for the gradient check to convert to double
107-
X_new = sample_from_gaussian(dim, N2).to(torch.float64)
108-
109-
if kernel_str == "von-mises-fisher": # normalization required
110-
if dim == 1:
111-
continue
112-
# Project the data onto the unit sphere
113-
X = X / X.norm(dim=1, keepdim=True)
114-
X_new = X_new / X_new.norm(dim=1, keepdim=True)
115-
116-
bandwidth = bandwidth.to(torch.float64)
117-
118-
X.requires_grad = True
119-
X_new.requires_grad = False
120-
bandwidth.requires_grad = False
121-
122-
# Check that the kernel density estimator is differentiable w.r.t. the training data
123-
fnc = partial(fit_and_eval, X_new=X_new, bandwidth=bandwidth)
124-
self.assertTrue(gradcheck(lambda X_: fnc(X=X_), (X,), raise_exception=False, eps=eps),
125-
f"""Kernel {kernel_str}, for dimensionality {str(dim)} is not differentiable w.r.t training data.""")
126-
127-
if not kernel_str == "von-mises-fisher": # normalization required
128-
X.requires_grad = False
129-
X_new.requires_grad = False
130-
bandwidth.requires_grad = True
131-
132-
# Check that the kernel density estimator is differentiable w.r.t. the bandwidth
133-
fnc = partial(fit_and_eval, X=X, X_new=X_new)
134-
self.assertTrue(gradcheck(lambda bandwidth_: fnc(bandwidth=bandwidth_), (bandwidth,), raise_exception=False, eps=eps),
135-
f"""Kernel {kernel_str}, for dimensionality {str(dim)} is not differentiable w.r.t. the bandwidth.""")
136-
137-
X.requires_grad = False
138-
X_new.requires_grad = True
139-
bandwidth.requires_grad = False
140-
141-
# Check that the kernel density estimator is differentiable w.r.t. the evaluation data
142-
fnc = partial(fit_and_eval, X=X, bandwidth=bandwidth)
143-
self.assertTrue(gradcheck(lambda X_new_: fnc(X_new=X_new_), (X_new,), raise_exception=False, eps=eps),
144-
f"""Kernel {kernel_str}, for dimensionality {str(dim)} is not differentiable w.r.t evaluation data.""")
145-
146-
147-
def sample_from_gaussian(dim, N):
148-
# sample data from a normal distribution
149-
mean = torch.zeros(dim)
150-
covariance_matrix = torch.eye(dim)
151-
152-
# Create the multivariate Gaussian distribution
153-
multivariate_normal = torch.distributions.MultivariateNormal(mean, covariance_matrix)
154-
X = multivariate_normal.sample((N,))
155-
return X
156-
157-
158-
if __name__ == "__main__":
159-
torch.manual_seed(0) # ensure reproducibility
160-
unittest.main()
1+
"""Tests that check whether the kernel density estimator behaves as expected."""
2+
3+
from itertools import product
4+
from functools import partial
5+
import math
6+
import random
7+
import unittest
8+
9+
import torch
10+
from torch import distributions as dist
11+
import numpy as np
12+
from scipy.special import gamma
13+
from torch.autograd import gradcheck
14+
15+
import torchkde
16+
from torchkde.kernels import *
17+
from torchkde.modules import KernelDensity
18+
from torchkde.bandwidths import SUPPORTED_BANDWIDTHS
19+
20+
# Bandwidths to sweep: two fixed scalar values plus every named
# bandwidth-selection rule the package supports.
BANDWIDTHS = [1.0, 5.0] + SUPPORTED_BANDWIDTHS
# Input dimensionalities exercised by the tests.
DIMS = [1, 2]
# Absolute tolerance for the "density integrates to 1" checks.
TOLERANCE = 1e-1
# Fit both without and with per-sample weights.
WEIGHTS = [False, True]

# Devices to test on (CPU only here).
DEVICES = ["cpu"]

# Number of training samples (N1) and of held-out evaluation points (N2).
N1 = 100
N2 = 10
# Grid resolution per dimension and half-width of the integration grid.
GRID_N = 1000
GRID_RG = 100
32+
33+
class TestKDE(unittest.TestCase):
    """Checks that the kernel density estimator behaves as expected."""

    def setUp(self):
        # Seed every RNG the tests touch so runs are reproducible.
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def test_integral(self):
        """Test that the kernel density estimator integrates to 1."""
        for kernel_str, bandwidth, dim, use_weights in product(SUPPORTED_KERNELS, BANDWIDTHS, DIMS, WEIGHTS):
            if kernel_str == 'von-mises-fisher':
                # Skip the von-mises-fisher kernel: it lives on the unit
                # sphere and is covered separately by test_vmf_integral.
                continue
            X = sample_from_gaussian(dim, N1)
            # Fit a kernel density estimator to the data.
            kde = KernelDensity(bandwidth=bandwidth, kernel=kernel_str)
            if use_weights:
                weights = torch.rand((N1,)).exp()  # arbitrary positive weights
                _ = kde.fit(X, sample_weight=weights)
            else:
                _ = kde.fit(X)
            # Assess whether the density integrates to 1 by evaluating the
            # estimator on a regular d-dimensional grid.
            ranges = [torch.linspace(-GRID_RG, GRID_RG, GRID_N) for _ in range(dim)]
            meshgrid = torch.meshgrid(*ranges, indexing='ij')  # 'ij' indexing for Cartesian coordinates
            # Convert meshgrid to a single tensor of shape (n_points, d).
            grid_points = torch.stack(meshgrid, dim=-1).reshape(-1, dim)
            probs = kde.score_samples(grid_points).exp()
            # Riemann-sum approximation: each grid cell has volume delta**dim.
            delta = (GRID_RG * 2) / GRID_N
            integral = probs.sum() * (delta**dim)
            self.assertTrue((integral - 1.0).abs() < TOLERANCE,
                            f"""Kernel {kernel_str}, for dimensionality {str(dim)}
                            and bandwidth {str(bandwidth)} does not integrate to 1.""")

    def test_vmf_integral(self):  # von-Mises-Fisher must be tested differently from other kernels
        """Test that the von-Mises-Fisher estimator integrates to 1 on the sphere.

        The integral is approximated by Monte-Carlo: normalized isotropic
        Gaussian samples are uniform on the sphere, so the mean density times
        the sphere's surface area estimates the integral.
        """
        for bandwidth, dim, use_weights in product(BANDWIDTHS, DIMS, WEIGHTS):
            if dim == 1 or isinstance(bandwidth, str):
                # vMF is skipped for dim == 1 and for the named
                # bandwidth-selection rules (string bandwidths).
                continue
            X = sample_from_gaussian(dim, N1)
            X = X / X.norm(dim=1, keepdim=True)  # project onto sphere
            # Fit a kernel density estimator to the data.
            kde = KernelDensity(bandwidth=bandwidth, kernel='von-mises-fisher')
            if use_weights:
                weights = torch.rand((N1,)).exp()  # arbitrary positive weights
                _ = kde.fit(X, sample_weight=weights)
            else:
                _ = kde.fit(X)
            # Assess whether the kernel integrates to 1 over the sphere.
            mesh_samples = sample_from_gaussian(dim, GRID_N**dim)
            mesh_samples = mesh_samples / mesh_samples.norm(dim=1, keepdim=True)  # project onto sphere

            probs = kde.score_samples(mesh_samples).exp()
            # Surface area of the unit (dim-1)-sphere embedded in R^dim.
            surface_area = 2 * math.pi ** (dim / 2) / gamma(dim / 2)
            integral = probs.mean() * surface_area
            self.assertTrue((integral - 1.0).abs() < TOLERANCE,
                            f"""Kernel von-mises-fisher, for dimensionality {str(dim)}
                            and bandwidth {str(bandwidth)} does not integrate to 1.""")

    def test_diffble(self, bandwidth=torch.tensor(1.0), eps=1e-07):
        """Test that the kernel density estimator is differentiable.

        Differentiability is checked w.r.t. the training data, the bandwidth
        (except for von-Mises-Fisher), and the evaluation data.
        NOTE(review): the tensor default argument is evaluated once at class
        definition time; harmless here since the method only rebinds a
        converted copy, but worth keeping in mind.
        """
        for kernel_str, dim in product(SUPPORTED_KERNELS, DIMS):
            def fit_and_eval(X, X_new, bandwidth):
                # Fit on X and score X_new so gradcheck differentiates
                # through the full fit + evaluate pipeline.
                kde = KernelDensity(bandwidth=bandwidth, kernel=kernel_str)
                _ = kde.fit(X)
                return kde.score_samples(X_new)
            X = sample_from_gaussian(dim, N1).to(torch.float64)  # relevant for the gradient check to convert to double
            X_new = sample_from_gaussian(dim, N2).to(torch.float64)

            if kernel_str == "von-mises-fisher":  # normalization required
                if dim == 1:
                    continue
                # Project the data onto the unit sphere
                X = X / X.norm(dim=1, keepdim=True)
                X_new = X_new / X_new.norm(dim=1, keepdim=True)

            bandwidth = bandwidth.to(torch.float64)

            X.requires_grad = True
            X_new.requires_grad = False
            bandwidth.requires_grad = False

            # Check that the kernel density estimator is differentiable w.r.t. the training data
            fnc = partial(fit_and_eval, X_new=X_new, bandwidth=bandwidth)
            self.assertTrue(gradcheck(lambda X_: fnc(X=X_), (X,), raise_exception=False, eps=eps),
                            f"""Kernel {kernel_str}, for dimensionality {str(dim)} is not differentiable w.r.t training data.""")

            if not kernel_str == "von-mises-fisher":  # bandwidth gradient not defined for vMF here
                X.requires_grad = False
                X_new.requires_grad = False
                bandwidth.requires_grad = True

                # Check that the kernel density estimator is differentiable w.r.t. the bandwidth
                fnc = partial(fit_and_eval, X=X, X_new=X_new)
                self.assertTrue(gradcheck(lambda bandwidth_: fnc(bandwidth=bandwidth_), (bandwidth,), raise_exception=False, eps=eps),
                                f"""Kernel {kernel_str}, for dimensionality {str(dim)} is not differentiable w.r.t. the bandwidth.""")

            X.requires_grad = False
            X_new.requires_grad = True
            bandwidth.requires_grad = False

            # Check that the kernel density estimator is differentiable w.r.t. the evaluation data
            fnc = partial(fit_and_eval, X=X, bandwidth=bandwidth)
            self.assertTrue(gradcheck(lambda X_new_: fnc(X_new=X_new_), (X_new,), raise_exception=False, eps=eps),
                            f"""Kernel {kernel_str}, for dimensionality {str(dim)} is not differentiable w.r.t evaluation data.""")

    def test_sampling_adheres_to_weights(self):
        """Test that sample_weight passed to fit() is respected by sample()."""
        n_samples = 2000

        # Create a GMM with 2 components with weights pi:
        # 90 % for component 1, 5 % for component 2 (does not need to sum to 1).
        pi = torch.tensor([0.9, 0.05])
        loc1 = torch.tensor([10., 10])
        cov1 = torch.diag(torch.tensor([1., 1]))
        loc2 = torch.tensor([0., 0])
        cov2 = torch.diag(torch.tensor([1., 1]))
        weights1 = torch.ones((n_samples,)) * pi[0]
        weights2 = torch.ones((n_samples,)) * pi[1]

        locs = torch.stack([loc1, loc2])  # Shape: [n_components, event_shape] = [2, 2]
        covs = torch.stack([cov1, cov2])  # Shape: [n_components, event_shape, event_shape] = [2, 2, 2]

        # 1. Create the batched component distribution.
        component_distribution = dist.multivariate_normal.MultivariateNormal(
            loc=locs,
            covariance_matrix=covs
        )

        # 2. Create the mixing distribution (Categorical);
        # pi is interpreted as weights in linear-scale.
        mixing_distribution = dist.Categorical(probs=pi)

        # 3. Create the MixtureSameFamily distribution, combining the
        # mixing and component distributions as the reference GMM.
        gmm = dist.MixtureSameFamily(
            mixture_distribution=mixing_distribution,
            component_distribution=component_distribution
        )

        X = component_distribution.sample((n_samples,))
        X1 = X[:, 0, :]  # samples from component 1
        X2 = X[:, 1, :]  # samples from component 2

        kde = torchkde.KernelDensity(bandwidth=.5, kernel='gaussian')  # create kde object with isotropic bandwidth matrix
        kde.fit(torch.concat((X1, X2), dim=0), sample_weight=torch.concat((weights1, weights2), dim=0))  # fit kde to weighted data

        samples_from_kde = kde.sample(n_samples)
        samples_from_gmm = gmm.sample((n_samples,))

        # The component means are separated along the first coordinate
        # (10 vs 0), so counting samples with x > 5 estimates each model's
        # component-1 fraction; the two fractions should agree within 10 %.
        component_1_fraction_kde = torch.count_nonzero(torch.where(samples_from_kde[:, 0] > 5., 1.0, 0.0)) / n_samples
        component_1_fraction_gmm = torch.count_nonzero(torch.where(samples_from_gmm[:, 0] > 5., 1.0, 0.0)) / n_samples
        self.assertTrue(0.9 < (component_1_fraction_kde / component_1_fraction_gmm) < 1.1,
                        "Component weights must be considered on sampling.")
195+
196+
197+
def sample_from_gaussian(dim, N):
    """Draw N samples from a standard multivariate normal of dimension dim.

    Returns a tensor of shape (N, dim).
    """
    # Zero mean, identity covariance: the standard Gaussian in R^dim.
    standard_normal = torch.distributions.MultivariateNormal(
        torch.zeros(dim), torch.eye(dim)
    )
    return standard_normal.sample((N,))
206+
207+
208+
# Allow running this test module directly, outside a test runner.
if __name__ == "__main__":
    torch.manual_seed(0)  # ensure reproducibility
    unittest.main()

0 commit comments

Comments
 (0)