
Commit 0c9ce71

Refactor kernels
1 parent f137857 commit 0c9ce71

File tree

1 file changed (+43 −46 lines)

torchkde/kernels.py

Lines changed: 43 additions & 46 deletions
@@ -28,6 +28,7 @@ def bandwidth(self):
 
     @bandwidth.setter
     def bandwidth(self, bandwidth):
+        self._norm_constant = None  # reset normalization constant when bandwidth changes
         self._bandwidth = bandwidth
         # compute H^(-1/2)
         if check_if_mat(bandwidth):
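
The added line makes the base setter invalidate the cached normalization constant whenever the bandwidth changes, so a kernel instance can be reused with a new bandwidth without serving a stale constant. A minimal sketch of the intended behavior, assuming GaussianKernel and its norm_constant property are used as defined in this file (the import path is an assumption):

import torch
from torchkde.kernels import GaussianKernel  # assumed import path

kernel = GaussianKernel()
kernel.bandwidth = torch.tensor(1.0)
x = torch.randn(5, 2)
_ = kernel(x, torch.zeros_like(x))  # first call sets dim=2 and caches the constant
c1 = kernel.norm_constant

kernel.bandwidth = torch.tensor(4.0)  # setter clears the cached constant
c2 = kernel.norm_constant             # lazily recomputed for the new bandwidth
assert not torch.isclose(c1, c2)
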
@@ -41,32 +42,38 @@ def norm_constant(self):
         assert self.dim is not None, "Dimension not set."
         self._norm_constant = self._compute_norm_constant(self.dim)
         return self._norm_constant
+
+    def _bw_norm(self, dim):
+        if check_if_mat(self._bandwidth):
+            return torch.sqrt(torch.det(self._bandwidth))
+        else:
+            return self._bandwidth ** (dim / 2)
 
     @abstractmethod
     def _compute_norm_constant(self, dim):
         pass
 
     @abstractmethod
     def __call__(self, x1, x2):
+        """This __call__ must be called in child classes to clear the norm constant cache and check the inputs."""
         assert self.bandwidth is not None, "Bandwidth not set."
+        assert x1.shape[-1] == x2.shape[-1], "Input data must have the same dimensionality."
+        new_dim = x1.shape[-1]
+        if new_dim != self.dim:  # first call or change of dimensionality
+            self.dim = new_dim
+            self._norm_constant = None
 
 
 class GaussianKernel(Kernel):
     def __call__(self, x1, x2):
         super().__call__(x1, x2)
         differences = x1 - x2
-        self.dim = differences.shape[-1]
-        u = kernel_input(self.inv_bandwidth, differences)
+        u = compute_u(self.inv_bandwidth, differences)
 
         return torch.exp(-u/2)
 
     def _compute_norm_constant(self, dim):
-        if check_if_mat(self._bandwidth):
-            # When bandwidth is a matrix, include sqrt(det(bandwidth))
-            bw_norm = torch.sqrt(torch.det(self._bandwidth))
-        else:
-            # When bandwidth is a scalar, raise it to the dim/2
-            bw_norm = self._bandwidth**(dim/2)
+        bw_norm = self._bw_norm(dim)
         return 1 / ((2 * math.pi)**(dim/2) * bw_norm)
 
 
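The new _bw_norm helper centralizes the bandwidth factor that each subclass's _compute_norm_constant previously duplicated: sqrt(det(H)) for a matrix bandwidth, h^(dim/2) for a scalar one. As a sanity check on the Gaussian case, 1 / ((2π)^(d/2) · sqrt(det(H))) is exactly the peak density of a zero-mean normal with covariance H; a self-contained sketch (not torchkde code, example bandwidth matrix chosen arbitrarily):

import math
import torch
from torch.distributions import MultivariateNormal

H = torch.tensor([[2.0, 0.3], [0.3, 1.0]])  # example positive-definite matrix bandwidth
d = H.shape[0]
norm_constant = 1 / ((2 * math.pi) ** (d / 2) * torch.sqrt(torch.det(H)))

mvn = MultivariateNormal(torch.zeros(d), covariance_matrix=H)
peak = mvn.log_prob(torch.zeros(d)).exp()  # density at the mean
assert torch.isclose(norm_constant, peak)
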
@@ -75,82 +82,72 @@ class TopHatKernel(Kernel):
     via a generalized Gaussian."""
     def __init__(self, beta=8):
         super().__init__()
-        assert type(beta) == int, "beta must be an integer."
+        assert isinstance(beta, int), "beta must be an integer."
         self.beta = beta
 
     def __call__(self, x1, x2):
         super().__call__(x1, x2)
         differences = x1 - x2
-        self.dim = differences.shape[-1]
-        u = kernel_input(self.inv_bandwidth, differences)
+        u = compute_u(self.inv_bandwidth, differences)
 
         return torch.exp(-(u**self.beta)/2)
 
     def _compute_norm_constant(self, dim):
-        if check_if_mat(self._bandwidth):
-            # When bandwidth is a matrix, include sqrt(det(bandwidth))
-            bw_norm = torch.sqrt(torch.det(self._bandwidth))
-        else:
-            # When bandwidth is a scalar, raise it to the d/2
-            bw_norm = self._bandwidth**(dim/2)
+        bw_norm = self._bw_norm(dim)
         return (self.beta*gamma(dim/2))/(math.pi**(dim/2) * \
             gamma(dim/(2*self.beta)) * 2**(dim/(2*self.beta)) * bw_norm)
 
 
 class EpanechnikovKernel(Kernel):
     def __init__(self):
         super().__init__()
-        self._intrinsic_norm_constant = None
+        self._unit_ball_constant = None
 
     def __call__(self, x1, x2):
+        old_dim = self.dim
         super().__call__(x1, x2)
         differences = x1 - x2
-        self.dim = differences.shape[-1]
-        c = self.intrinsic_norm_constant
-        u = kernel_input(self.inv_bandwidth, differences)
+        if old_dim is not None and old_dim != differences.shape[-1]:
+            self._unit_ball_constant = None
+        c = self.unit_ball_constant
+        u = compute_u(self.inv_bandwidth, differences)
 
         return torch.where(u > 1, 0, c * (1 - u))
 
-    def _compute_intrinsic_norm_constant(self, dim):
+    @Kernel.bandwidth.setter
+    def bandwidth(self, bandwidth):
+        Kernel.bandwidth.fset(self, bandwidth)
+        self._unit_ball_constant = None  # reset the cached constant when bandwidth changes
+
+    def _compute_unit_ball_constant(self, dim):
         return ((dim + 2)*gamma(dim/2 + 1))/(2*math.pi**(dim/2))
 
     @property
-    def intrinsic_norm_constant(self):
+    def unit_ball_constant(self):
         """Return the cached intrinsic normalization constant, computing it if necessary."""
-        if self._intrinsic_norm_constant is None:
-            self._intrinsic_norm_constant = self._compute_intrinsic_norm_constant(self.dim)
-        return self._intrinsic_norm_constant
+        if self._unit_ball_constant is None:
+            self._unit_ball_constant = self._compute_unit_ball_constant(self.dim)
+        return self._unit_ball_constant
 
     def _compute_norm_constant(self, dim):
-        if check_if_mat(self._bandwidth):
-            # When bandwidth is a matrix, include sqrt(det(bandwidth))
-            bw_norm = torch.sqrt(torch.det(self._bandwidth))
-        else:
-            # When bandwidth is a scalar, raise it to the dim/2
-            bw_norm = self._bandwidth**(dim/2)
+        bw_norm = self._bw_norm(dim)
         return 1 / bw_norm
 
 
 class ExponentialKernel(Kernel):
     def __call__(self, x1, x2):
         super().__call__(x1, x2)
         differences = x1 - x2
-        self.dim = differences.shape[-1]
-        u = kernel_input(self.inv_bandwidth, differences, exp=1)
+        u = compute_u(self.inv_bandwidth, differences, exp=1)
 
         return torch.exp(-u)
 
     def _compute_norm_constant(self, dim):
-        if check_if_mat(self._bandwidth):
-            # When bandwidth is a matrix, include sqrt(det(bandwidth))
-            bw_norm = torch.sqrt(torch.det(self._bandwidth))
-        else:
-            # When bandwidth is a scalar, raise it to the dim/2
-            bw_norm = self._bandwidth**(dim/2)
+        bw_norm = self._bw_norm(dim)
         return 1/(2**dim * bw_norm)
 
 
-def kernel_input(inv_bandwidth, x, exp=2):
+def compute_u(inv_bandwidth, x, exp=2):
     """Compute the input to the kernel function."""
     if exp >= 2:
         if check_if_mat(inv_bandwidth):
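
This hunk switches the beta check to isinstance, renames the Epanechnikov cache to the more descriptive unit_ball_constant (now also reset by a dedicated bandwidth setter), and renames kernel_input to compute_u. Two quick spot checks on the formulas it touches, in plain Python independent of torchkde:

import math

# Epanechnikov: ((d + 2) * gamma(d/2 + 1)) / (2 * pi^(d/2)),
# mirroring _compute_unit_ball_constant above
def unit_ball_constant(dim):
    return ((dim + 2) * math.gamma(dim / 2 + 1)) / (2 * math.pi ** (dim / 2))

assert abs(unit_ball_constant(1) - 0.75) < 1e-12  # the familiar 1D factor 3/4

# Top hat via a generalized Gaussian: exp(-(u**beta)/2) approaches an
# indicator of u < 1 as beta grows (beta=8 is the class default)
for beta in (1, 8, 64):
    inside = math.exp(-(0.5 ** beta) / 2)   # u inside the support -> 1
    outside = math.exp(-(1.5 ** beta) / 2)  # u outside the support -> 0
    print(beta, round(inside, 4), round(outside, 4))
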
@@ -167,12 +164,13 @@ def kernel_input(inv_bandwidth, x, exp=2):
 class VonMisesFisherKernel(Kernel):
     @Kernel.bandwidth.setter
     def bandwidth(self, bandwidth):
+        Kernel.bandwidth.fset(self, bandwidth)
         # For vMF, the bandwidth is directly the concentration parameter.
-        if type(bandwidth) == torch.Tensor:
-            assert bandwidth.requires_grad == False, \
+        if isinstance(bandwidth, torch.Tensor):
+            assert not bandwidth.requires_grad, \
                 "The bandwidth for the von Mises-Fisher kernel must not require gradients."
-            bandwidth = bandwidth.item()
-        assert type(bandwidth) == float or isinstance(bandwidth, torch.Tensor) and bandwidth.dim() == 0, \
+            bandwidth = bandwidth.item()  # input to iv function cannot handle tensors
+        assert isinstance(bandwidth, float) or isinstance(bandwidth, torch.Tensor) and bandwidth.dim() == 0, \
             "The bandwidth for the von Mises-Fisher kernel must be a scalar."
         self._bandwidth = bandwidth
         self.inv_bandwidth = bandwidth
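
Delegating to Kernel.bandwidth.fset lets the vMF override run the base setter's cache invalidation before its own tensor-to-float conversion (needed because, per the inline comment, the iv Bessel function used downstream cannot handle tensors). The property-delegation pattern in isolation, with illustrative names rather than torchkde's:

class Base:
    @property
    def bandwidth(self):
        return self._bandwidth

    @bandwidth.setter
    def bandwidth(self, value):
        self._cache = None  # base class invalidates its cache
        self._bandwidth = value

class Child(Base):
    @Base.bandwidth.setter
    def bandwidth(self, value):
        Base.bandwidth.fset(self, value)  # run the parent setter first
        self._extra_cache = None          # then child-specific invalidation

c = Child()
c.bandwidth = 2.0
print(c.bandwidth)  # 2.0, with both caches cleared
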
@@ -183,7 +181,6 @@ def __call__(self, x1, x2):
         assert torch.allclose(
             x_all.norm(dim=-1), torch.ones_like(x_all[..., 0]), atol=1e-5
         ), "The von Mises-Fisher kernel assumes all data to lie on the unit sphere. Please normalize data."
-        self.dim = x1.shape[-1]
 
         return torch.exp(self._bandwidth * (x1 * x2).sum(dim=-1))
 
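With dimensionality tracking moved into the base __call__, the vMF kernel now only verifies that inputs lie on the unit sphere and evaluates exp(kappa * <x1, x2>). A small sketch of that evaluation, mirroring the return expression above (plain PyTorch, not torchkde code):

import torch

x = torch.randn(4, 3)
x = x / x.norm(dim=-1, keepdim=True)  # normalize, as the assertion requires
kappa = 5.0  # concentration parameter, which vMF uses as its bandwidth
similarity = torch.exp(kappa * (x * x).sum(dim=-1))  # here x1 == x2
print(similarity)  # exp(kappa) everywhere, since <x, x> == 1 on the sphere
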
0 commit comments
