55from .utils import ensure_two_dimensional , check_if_mat
66from .algorithms import RootTree , SUPPORTED_ALGORITHMS
77from .bandwidths import SUPPORTED_BANDWIDTHS , compute_bandwidth
8- from .kernels import (GaussianKernel ,
9- EpanechnikovKernel ,
10- ExponentialKernel ,
11- TopHatKernel ,
8+ from .kernels import (GaussianKernel ,
9+ EpanechnikovKernel ,
10+ ExponentialKernel ,
11+ TopHatKernel ,
1212 VonMisesFisherKernel ,
1313 SUPPORTED_KERNELS )
1414
2828
2929
3030class KernelDensity (nn .Module ):
31- """Analag to the KernelDensity class in sklearn.neighbors
31+ """Analog to the KernelDensity class in sklearn.neighbors
3232 (see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/neighbors/_kde.py)."""
3333
3434 def __init__ (
@@ -37,7 +37,8 @@ def __init__(
3737 bandwidth : Union [float , str ] = 1.0 ,
3838 algorithm : str = "standard" ,
3939 kernel : str = "gaussian" ,
40- kernel_kwargs : dict = None
40+ kernel_kwargs : dict = None ,
41+ eps : float = 0.0 ,
4142 ) -> None :
4243 """Initialize the KernelDensity estimator.
4344
@@ -51,7 +52,10 @@ def __init__(
5152 The kernel to use for density estimation.
5253 kernel_kwargs : dict, optional
5354 Additional keyword arguments for the kernel.
55+ eps : float, optional
56+ Lower bound the densities are clamped to before taking the log. A small positive value avoids -inf; the default 0.0 still yields -inf for zero density.
5457 """
58+ super ().__init__ ()
5559 if not isinstance (bandwidth , str ):
5660 assert bandwidth > 0 , "Bandwidth must be positive."
5761 self .bandwidth = bandwidth ** 2 # square the bandwidth to match sklearn's implementation
@@ -69,19 +73,20 @@ def __init__(
6973 self .device = None
7074 self .n_features = None
7175 self .data = None
76+ self .eps = eps
7277
7378 if algorithm not in SUPPORTED_ALGORITHMS :
7479 raise ValueError (f"Algorithm { algorithm } not supported" )
75-
80+
7681 if kernel not in SUPPORTED_KERNELS :
7782 raise ValueError (f"Kernel { kernel } not supported" )
7883
7984 if not isinstance (bandwidth , (float , torch .Tensor )) and bandwidth not in SUPPORTED_BANDWIDTHS :
8085 raise ValueError (f"Bandwidth { bandwidth } not supported" )
8186
8287
83- def fit (self ,
84- X : torch .Tensor ,
88+ def fit (self ,
89+ X : torch .Tensor ,
8590 sample_weight : Optional [torch .Tensor ] = None
8691 ) -> 'KernelDensity' :
8792 """Fit the Kernel Density model on the data.
@@ -141,7 +146,7 @@ def score_samples(self, X: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
141146 """
142147 assert self .is_fitted , "Model must be fitted before scoring samples."
143148 assert X .device == self .device , "Device of the query data must be on the same device as the data for fitting the estimator."
144-
149+
145150 n_samples = X .shape [0 ]
146151 # Compute log-density estimation with a kernel function
147152 log_density = []
@@ -156,7 +161,7 @@ def score_samples(self, X: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
156161 density = ((self .sample_weight * kernel_values ).sum (- 1 ) * self .kernel_module .norm_constant ) \
157162 / self .sample_weight .sum ()
158163 # Compute the log-density
159- log_density .append (density .log ())
164+ log_density .append (density .clamp ( min = self . eps ). log ())
160165
161166 # Convert the list of log-density values into a tensor
162167 log_density = torch .cat (log_density , dim = 0 )
0 commit comments