Skip to content

Commit 24eb26c

Browse files
committed
2 parents 1a106a5 + 2de07c2 commit 24eb26c

File tree

4 files changed

+10
-2
lines changed

4 files changed

+10
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ The current implementation provides the following functionality:
7777
| Kernels | Gaussian, Epanechnikov, Exponential, Tophat Approximation, von Mises-Fisher (data must lie on the unit sphere) |
7878
| Tree Algorithms | Standard |
7979
| Bandwidths | Float (Isotropic bandwidth matrix), Scott, Silverman |
80+
| Devices | CPU, GPU |
8081

8182
</div>
8283

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "torch-kde"
7-
version = "0.1.3"
7+
version = "0.1.4"
88
description = "A differentiable implementation of kernel density estimation in PyTorch"
99
readme = "README_PyPI.md"
1010
license = { text = "MIT" }

tests/test_kde.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
TOLERANCE = 1e-1
2121
WEIGHTS = [False, True]
2222

23+
DEVICES = ["cpu"]
24+
2325
N1 = 100
2426
N2 = 10
2527
GRID_N = 1000

torchkde/modules.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
self.algorithm = algorithm
6767
self.is_fitted = False
6868
self.n_samples = None
69+
self.device = None
6970
self.n_features = None
7071
self.data = None
7172

@@ -106,14 +107,17 @@ def fit(self,
106107
self.tree_.build(X)
107108
self.bandwidth = compute_bandwidth(X, self.bandwidth)
108109
self.kernel_module.bandwidth = self.bandwidth
110+
self.device = X.device
109111
if sample_weight is None:
110-
self.sample_weight = torch.ones(self.n_samples)
112+
self.sample_weight = torch.ones(self.n_samples).to(self.device)
111113
else:
112114
assert sample_weight.shape[0] == self.n_samples, "Sample weights must have the same length as the data."
115+
assert sample_weight.device == self.device, "Sample weights must be on the same device as the data."
113116
assert sample_weight.dim() == 1, "Sample weights must be one-dimensional."
114117
assert sample_weight.min() >= 0, "Sample weights must be non-negative."
115118
self.sample_weight = sample_weight
116119
self.is_fitted = True
120+
self.device = X.device
117121

118122
return self
119123

@@ -136,6 +140,7 @@ def score_samples(self, X: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
136140
data.
137141
"""
138142
assert self.is_fitted, "Model must be fitted before scoring samples."
143+
assert X.device == self.device, "Device of the query data must be on the same device as the data for fitting the estimator."
139144

140145
n_samples = X.shape[0]
141146
# Compute log-density estimation with a kernel function

0 commit comments

Comments
 (0)