Skip to content

Commit 4a485fb

Browse files
committed
Add type annotations
1 parent efa5146 commit 4a485fb

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

torchkde/modules.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional, Union
12
import torch
23
from torch import nn
34

@@ -31,11 +32,24 @@ class KernelDensity(nn.Module):
3132
def __init__(
3233
self,
3334
*,
34-
bandwidth=1.0,
35-
algorithm="standard",
36-
kernel="gaussian",
37-
kernel_kwargs=None
38-
):
35+
bandwidth: Union[float, str] = 1.0,
36+
algorithm: str = "standard",
37+
kernel: str = "gaussian",
38+
kernel_kwargs: dict = None
39+
) -> None:
40+
"""Initialize the KernelDensity estimator.
41+
42+
Parameters
43+
----------
44+
bandwidth : float or str, default=1.0
45+
The bandwidth of the kernel.
46+
algorithm : str, default="standard"
47+
The algorithm to use for kernel density estimation.
48+
kernel : str, default="gaussian"
49+
The kernel to use for density estimation.
50+
kernel_kwargs : dict, optional
51+
Additional keyword arguments for the kernel.
52+
"""
3953
if not isinstance(bandwidth, str):
4054
assert bandwidth > 0, "Bandwidth must be positive."
4155
self.bandwidth = bandwidth**2 # square the bandwidth to match sklearn's implementation
@@ -63,7 +77,10 @@ def __init__(
6377
raise ValueError(f"Bandwidth {bandwidth} not supported")
6478

6579

66-
def fit(self, X, sample_weight=None):
80+
def fit(self,
81+
X: torch.Tensor,
82+
sample_weight: Optional[torch.Tensor] = None
83+
) -> 'KernelDensity':
6784
"""Fit the Kernel Density model on the data.
6885
6986
Parameters
@@ -98,7 +115,7 @@ def fit(self, X, sample_weight=None):
98115

99116
return self
100117

101-
def score_samples(self, X):
118+
def score_samples(self, X: torch.Tensor) -> torch.Tensor:
102119
"""Compute the log-likelihood of each sample under the model.
103120
104121
Parameters
@@ -136,7 +153,7 @@ def score_samples(self, X):
136153
return log_density
137154

138155

139-
def score(self, X):
156+
def score(self, X: torch.Tensor) -> float:
140157
"""Compute the total log-likelihood under the model.
141158
142159
Parameters
@@ -154,7 +171,7 @@ def score(self, X):
154171
"""
155172
return self.score_samples(X).sum()
156173

157-
def sample(self, n_samples=1):
174+
def sample(self, n_samples: int = 1) -> torch.Tensor:
158175
"""Generate random samples from the model.
159176
160177
Parameters
@@ -178,4 +195,5 @@ def sample(self, n_samples=1):
178195
X = self.bandwidth * torch.randn(n_samples, data.shape[1]) + data[idxs]
179196

180197
return ensure_two_dimensional(X)
198+
181199

0 commit comments

Comments
 (0)