Skip to content

Commit cc77f01

Browse files
committed
Closes #5 (scoring is much quicker now)
1 parent 4a485fb commit cc77f01

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

torchkde/modules.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,17 @@ def fit(self,
115115

116116
return self
117117

118-
def score_samples(self, X: torch.Tensor) -> torch.Tensor:
118+
def score_samples(self, X: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
119119
"""Compute the log-likelihood of each sample under the model.
120120
121121
Parameters
122122
----------
123123
X : torch Tensor of shape (n_samples, n_features)
124124
An array of points to query. Last dimension should match dimension
125125
of training data (n_features).
126+
batch_size : int, default=64
127+
Number of samples to process in each batch.
128+
126129
Returns
127130
-------
128131
density : torch Tensor of shape (n_samples,)
@@ -131,16 +134,19 @@ def score_samples(self, X: torch.Tensor) -> torch.Tensor:
131134
data.
132135
"""
133136
assert self.is_fitted, "Model must be fitted before scoring samples."
134-
135-
X_neighbors = self.tree_.query(X, return_distance=False)
137+
138+
n_samples = X.shape[0]
136139
# Compute log-density estimation with a kernel function
137140
log_density = []
138141
# Compute normalization part from bandwidth matrix
139142
bw_norm = torch.sqrt(torch.det(self.bandwidth)) if check_if_mat(self.bandwidth) else self.bandwidth**(self.n_features/2)
140143
# looping to avoid memory issues
141-
for i, x in enumerate(X):
144+
for start in range(0, n_samples, batch_size):
145+
end = min(start + batch_size, n_samples)
146+
X_batch = X[start:end]
147+
X_neighbors = self.tree_.query(X_batch, return_distance=False)
142148
# Compute pairwise differences between the current point and neighbors
143-
differences = x - X_neighbors[i]
149+
differences = X_batch.unsqueeze(1) - X_neighbors
144150
# Apply the kernel function to each difference
145151
kernel_values = self.kernel_module(differences)
146152
# Sum kernel values and normalize
@@ -149,7 +155,8 @@ def score_samples(self, X: torch.Tensor) -> torch.Tensor:
149155
log_density.append(density.log())
150156

151157
# Convert the list of log-density values into a tensor
152-
log_density = torch.stack(log_density, dim=0)
158+
log_density = torch.cat(log_density, dim=0)
159+
153160
return log_density
154161

155162

@@ -195,5 +202,4 @@ def sample(self, n_samples: int = 1) -> torch.Tensor:
195202
X = self.bandwidth * torch.randn(n_samples, data.shape[1]) + data[idxs]
196203

197204
return ensure_two_dimensional(X)
198-
199205

0 commit comments

Comments
 (0)