1+ from typing import Optional , Union
12import torch
23from torch import nn
34
@@ -31,11 +32,24 @@ class KernelDensity(nn.Module):
3132 def __init__ (
3233 self ,
3334 * ,
34- bandwidth = 1.0 ,
35- algorithm = "standard" ,
36- kernel = "gaussian" ,
37- kernel_kwargs = None
38- ):
35+ bandwidth : Union [float , str ] = 1.0 ,
36+ algorithm : str = "standard" ,
37+ kernel : str = "gaussian" ,
38+ kernel_kwargs : dict = None
39+ ) -> None :
40+ """Initialize the KernelDensity estimator.
41+
42+ Parameters
43+ ----------
44+ bandwidth : float or str, default=1.0
45+ The bandwidth of the kernel.
46+ algorithm : str, default="standard"
47+ The algorithm to use for kernel density estimation.
48+ kernel : str, default="gaussian"
49+ The kernel to use for density estimation.
50+ kernel_kwargs : dict, optional
51+ Additional keyword arguments for the kernel.
52+ """
3953 if not isinstance (bandwidth , str ):
4054 assert bandwidth > 0 , "Bandwidth must be positive."
4155 self .bandwidth = bandwidth ** 2 # square the bandwidth to match sklearn's implementation
@@ -63,7 +77,10 @@ def __init__(
6377 raise ValueError (f"Bandwidth { bandwidth } not supported" )
6478
6579
66- def fit (self , X , sample_weight = None ):
80+ def fit (self ,
81+ X : torch .Tensor ,
82+ sample_weight : Optional [torch .Tensor ] = None
83+ ) -> 'KernelDensity' :
6784 """Fit the Kernel Density model on the data.
6885
6986 Parameters
@@ -98,7 +115,7 @@ def fit(self, X, sample_weight=None):
98115
99116 return self
100117
101- def score_samples (self , X ) :
118+ def score_samples (self , X : torch . Tensor ) -> torch . Tensor :
102119 """Compute the log-likelihood of each sample under the model.
103120
104121 Parameters
@@ -136,7 +153,7 @@ def score_samples(self, X):
136153 return log_density
137154
138155
139- def score (self , X ) :
156+ def score (self , X : torch . Tensor ) -> float :
140157 """Compute the total log-likelihood under the model.
141158
142159 Parameters
@@ -154,7 +171,7 @@ def score(self, X):
154171 """
155172 return self .score_samples (X ).sum ()
156173
157- def sample (self , n_samples = 1 ) :
174+ def sample (self , n_samples : int = 1 ) -> torch . Tensor :
158175 """Generate random samples from the model.
159176
160177 Parameters
@@ -178,4 +195,5 @@ def sample(self, n_samples=1):
178195 X = self .bandwidth * torch .randn (n_samples , data .shape [1 ]) + data [idxs ]
179196
180197 return ensure_two_dimensional (X )
198+
181199
0 commit comments