@@ -115,14 +115,17 @@ def fit(self,
115115
116116 return self
117117
118- def score_samples (self , X : torch .Tensor ) -> torch .Tensor :
118+ def score_samples (self , X : torch .Tensor , batch_size : int = 128 ) -> torch .Tensor :
119119 """Compute the log-likelihood of each sample under the model.
120120
121121 Parameters
122122 ----------
123123 X : torch Tensor of shape (n_samples, n_features)
124124 An array of points to query. Last dimension should match dimension
125125 of training data (n_features).
126+ batch_size : int, default=64
127+ Number of samples to process in each batch.
128+
126129 Returns
127130 -------
128131 density : torch Tensor of shape (n_samples,)
@@ -131,16 +134,19 @@ def score_samples(self, X: torch.Tensor) -> torch.Tensor:
131134 data.
132135 """
133136 assert self .is_fitted , "Model must be fitted before scoring samples."
134-
135- X_neighbors = self . tree_ . query ( X , return_distance = False )
137+
138+ n_samples = X . shape [ 0 ]
136139 # Compute log-density estimation with a kernel function
137140 log_density = []
138141 # Compute normalization part from bandwidth matrix
139142 bw_norm = torch .sqrt (torch .det (self .bandwidth )) if check_if_mat (self .bandwidth ) else self .bandwidth ** (self .n_features / 2 )
140143 # looping to avoid memory issues
141- for i , x in enumerate (X ):
144+ for start in range (0 , n_samples , batch_size ):
145+ end = min (start + batch_size , n_samples )
146+ X_batch = X [start :end ]
147+ X_neighbors = self .tree_ .query (X_batch , return_distance = False )
142148 # Compute pairwise differences between the current point and neighbors
143- differences = x - X_neighbors [ i ]
149+ differences = X_batch . unsqueeze ( 1 ) - X_neighbors
144150 # Apply the kernel function to each difference
145151 kernel_values = self .kernel_module (differences )
146152 # Sum kernel values and normalize
@@ -149,7 +155,8 @@ def score_samples(self, X: torch.Tensor) -> torch.Tensor:
149155 log_density .append (density .log ())
150156
151157 # Convert the list of log-density values into a tensor
152- log_density = torch .stack (log_density , dim = 0 )
158+ log_density = torch .cat (log_density , dim = 0 )
159+
153160 return log_density
154161
155162
@@ -195,5 +202,4 @@ def sample(self, n_samples: int = 1) -> torch.Tensor:
195202 X = self .bandwidth * torch .randn (n_samples , data .shape [1 ]) + data [idxs ]
196203
197204 return ensure_two_dimensional (X )
198-
199205
0 commit comments