Skip to content

Commit edd19b7

Browse files
Merge pull request #12 from amreis/improve-stability
Improve stability and composability
2 parents 171a43b + b230734 commit edd19b7

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

torchkde/modules.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from .utils import ensure_two_dimensional, check_if_mat
66
from .algorithms import RootTree, SUPPORTED_ALGORITHMS
77
from .bandwidths import SUPPORTED_BANDWIDTHS, compute_bandwidth
8-
from .kernels import (GaussianKernel,
9-
EpanechnikovKernel,
10-
ExponentialKernel,
11-
TopHatKernel,
8+
from .kernels import (GaussianKernel,
9+
EpanechnikovKernel,
10+
ExponentialKernel,
11+
TopHatKernel,
1212
VonMisesFisherKernel,
1313
SUPPORTED_KERNELS)
1414

@@ -28,7 +28,7 @@
2828

2929

3030
class KernelDensity(nn.Module):
31-
"""Analog to the KernelDensity class in sklearn.neighbors
31+
"""Analog to the KernelDensity class in sklearn.neighbors
3232
(see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/neighbors/_kde.py)."""
3333

3434
def __init__(
@@ -37,7 +37,8 @@ def __init__(
3737
bandwidth: Union[float, str] = 1.0,
3838
algorithm: str = "standard",
3939
kernel: str = "gaussian",
40-
kernel_kwargs: dict = None
40+
kernel_kwargs: dict = None,
41+
eps: float = 0.0,
4142
) -> None:
4243
"""Initialize the KernelDensity estimator.
4344
@@ -51,7 +52,10 @@ def __init__(
5152
The kernel to use for density estimation.
5253
kernel_kwargs : dict, optional
5354
Additional keyword arguments for the kernel.
55+
eps : float, optional
56+
Lower bound applied to densities before taking the log. Set it to a small positive value to avoid -inf; the default of 0.0 leaves zero densities unclamped.
5457
"""
58+
super().__init__()
5559
if not isinstance(bandwidth, str):
5660
assert bandwidth > 0, "Bandwidth must be positive."
5761
self.bandwidth = bandwidth**2 # square the bandwidth to match sklearn's implementation
@@ -69,19 +73,20 @@ def __init__(
6973
self.device = None
7074
self.n_features = None
7175
self.data = None
76+
self.eps = eps
7277

7378
if algorithm not in SUPPORTED_ALGORITHMS:
7479
raise ValueError(f"Algorithm {algorithm} not supported")
75-
80+
7681
if kernel not in SUPPORTED_KERNELS:
7782
raise ValueError(f"Kernel {kernel} not supported")
7883

7984
if not isinstance(bandwidth, (float, torch.Tensor)) and bandwidth not in SUPPORTED_BANDWIDTHS:
8085
raise ValueError(f"Bandwidth {bandwidth} not supported")
8186

8287

83-
def fit(self,
84-
X: torch.Tensor,
88+
def fit(self,
89+
X: torch.Tensor,
8590
sample_weight: Optional[torch.Tensor] = None
8691
) -> 'KernelDensity':
8792
"""Fit the Kernel Density model on the data.
@@ -141,7 +146,7 @@ def score_samples(self, X: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
141146
"""
142147
assert self.is_fitted, "Model must be fitted before scoring samples."
143148
assert X.device == self.device, "Device of the query data must be on the same device as the data for fitting the estimator."
144-
149+
145150
n_samples = X.shape[0]
146151
# Compute log-density estimation with a kernel function
147152
log_density = []
@@ -156,7 +161,7 @@ def score_samples(self, X: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
156161
density = ((self.sample_weight * kernel_values).sum(-1) * self.kernel_module.norm_constant) \
157162
/ self.sample_weight.sum()
158163
# Compute the log-density
159-
log_density.append(density.log())
164+
log_density.append(density.clamp(min=self.eps).log())
160165

161166
# Convert the list of log-density values into a tensor
162167
log_density = torch.cat(log_density, dim=0)

0 commit comments

Comments
 (0)