Skip to content

Commit 7867372

Browse files
committed
Enable GPU support. Closes #6.
1 parent 9e377e1 commit 7867372

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchkde/modules.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
self.algorithm = algorithm
6565
self.is_fitted = False
6666
self.n_samples = None
67+
self.device = None
6768
self.n_features = None
6869
self.data = None
6970

@@ -104,14 +105,17 @@ def fit(self,
104105
self.tree_.build(X)
105106
self.bandwidth = compute_bandwidth(X, self.bandwidth)
106107
self.kernel_module.bandwidth = self.bandwidth
108+
self.device = X.device
107109
if sample_weight is None:
108-
self.sample_weight = torch.ones(self.n_samples)
110+
self.sample_weight = torch.ones(self.n_samples).to(self.device)
109111
else:
110112
assert sample_weight.shape[0] == self.n_samples, "Sample weights must have the same length as the data."
113+
assert sample_weight.device == self.device, "Sample weights must be on the same device as the data."
111114
assert sample_weight.dim() == 1, "Sample weights must be one-dimensional."
112115
assert sample_weight.min() >= 0, "Sample weights must be non-negative."
113116
self.sample_weight = sample_weight
114117
self.is_fitted = True
118+
self.device = X.device
115119

116120
return self
117121

@@ -134,6 +138,7 @@ def score_samples(self, X: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
134138
data.
135139
"""
136140
assert self.is_fitted, "Model must be fitted before scoring samples."
141+
assert X.device == self.device, "Device of the query data must be on the same device as the data for fitting the estimator."
137142

138143
n_samples = X.shape[0]
139144
# Compute log-density estimation with a kernel function

0 commit comments

Comments
 (0)