|
33 | 33 | from .._device_offload import support_usm_ndarray |
34 | 34 |
|
35 | 35 | if sklearn_check_version('0.22'): |
36 | | - from sklearn.manifold._t_sne import _joint_probabilities, _joint_probabilities_nn |
| 36 | + from sklearn.manifold._t_sne import _joint_probabilities |
| 37 | + from sklearn.manifold._t_sne import _joint_probabilities_nn |
37 | 38 | else: |
38 | | - from sklearn.manifold.t_sne import _joint_probabilities, _joint_probabilities_nn |
| 39 | + from sklearn.manifold.t_sne import _joint_probabilities |
| 40 | + from sklearn.manifold.t_sne import _joint_probabilities_nn |
39 | 41 |
|
40 | 42 |
|
41 | 43 | class TSNE(BaseTSNE): |
@@ -98,8 +100,19 @@ def _daal_tsne(self, P, n_samples, X_embedded): |
98 | 100 | # * final optimization with momentum at 0.8 |
99 | 101 |
|
100 | 102 | # N, nnz, n_iter_without_progress, n_iter |
101 | | - size_iter = np.array([[n_samples], [P.nnz], [self.n_iter_without_progress], |
102 | | - [self.n_iter]], dtype=P.dtype) |
| 103 | + size_iter = [[n_samples], [P.nnz], |
| 104 | + [self.n_iter_without_progress], |
| 105 | + [self.n_iter]] |
| 106 | + |
| 107 | + # Pass params to daal4py backend |
| 108 | + if daal_check_version((2023, 'P', 1)): |
| 109 | + size_iter.extend( |
| 110 | + [[self._EXPLORATION_N_ITER], |
| 111 | + [self._N_ITER_CHECK]] |
| 112 | + ) |
| 113 | + |
| 114 | + size_iter = np.array(size_iter, dtype=P.dtype) |
| 115 | + |
103 | 116 | params = np.array([[self.early_exaggeration], [self._learning_rate], |
104 | 117 | [self.min_grad_norm], [self.angle]], dtype=P.dtype) |
105 | 118 | results = np.zeros((3, 1), dtype=P.dtype) # curIter, error, gradNorm |
@@ -164,17 +177,28 @@ def _fit(self, X, skip_num_points=0): |
164 | 177 | "or 'auto'.") |
165 | 178 |
|
166 | 179 | if hasattr(self, 'square_distances'): |
167 | | - if self.square_distances not in [True, 'legacy', 'deprecated']: |
168 | | - raise ValueError("'square_distances' must be True or 'legacy'.") |
169 | | - if self.metric != "euclidean" and self.square_distances is not True: |
170 | | - warnings.warn(("'square_distances' has been introduced in 0.24" |
171 | | - "to help phase out legacy squaring behavior. The " |
172 | | - "'legacy' setting will be removed in 0.26, and the " |
173 | | - "default setting will be changed to True. In 0.28, " |
174 | | - "'square_distances' will be removed altogether," |
175 | | - "and distances will be squared by default. Set " |
176 | | - "'square_distances'=True to silence this warning."), |
177 | | - FutureWarning) |
| 180 | + if sklearn_check_version("1.1"): |
| 181 | + if self.square_distances != "deprecated": |
| 182 | + warnings.warn( |
| 183 | + "The parameter `square_distances` has not effect " |
| 184 | + "and will be removed in version 1.3.", |
| 185 | + FutureWarning, |
| 186 | + ) |
| 187 | + else: |
| 188 | + if self.square_distances not in [True, "legacy"]: |
| 189 | + raise ValueError( |
| 190 | + "'square_distances' must be True or 'legacy'.") |
| 191 | + if self.metric != "euclidean" and self.square_distances is not True: |
| 192 | + warnings.warn( |
| 193 | + "'square_distances' has been introduced in 0.24 to help phase " |
| 194 | + "out legacy squaring behavior. The 'legacy' setting will be " |
| 195 | + "removed in 1.1 (renaming of 0.26), and the default setting " |
| 196 | + "will be changed to True. In 1.3, 'square_distances' will be " |
| 197 | + "removed altogether, and distances will be squared by " |
| 198 | + "default. Set 'square_distances'=True to silence this " |
| 199 | + "warning.", |
| 200 | + FutureWarning, |
| 201 | + ) |
178 | 202 |
|
179 | 203 | if self.method == 'barnes_hut': |
180 | 204 | if sklearn_check_version('0.23'): |
@@ -242,8 +266,12 @@ def _fit(self, X, skip_num_points=0): |
242 | 266 | distances = pairwise_distances(X, metric=self.metric, |
243 | 267 | squared=True) |
244 | 268 | else: |
| 269 | + metric_params_ = {} |
| 270 | + if sklearn_check_version('1.1'): |
| 271 | + metric_params_ = self.metric_params or {} |
245 | 272 | distances = pairwise_distances(X, metric=self.metric, |
246 | | - n_jobs=self.n_jobs) |
| 273 | + n_jobs=self.n_jobs, |
| 274 | + **metric_params_) |
247 | 275 |
|
248 | 276 | if np.any(distances < 0): |
249 | 277 | raise ValueError("All distances should be positive, the " |
@@ -272,12 +300,22 @@ def _fit(self, X, skip_num_points=0): |
272 | 300 | .format(n_neighbors)) |
273 | 301 |
|
274 | 302 | # Find the nearest neighbors for every point |
275 | | - knn = NearestNeighbors( |
276 | | - algorithm='auto', |
277 | | - n_jobs=self.n_jobs, |
278 | | - n_neighbors=n_neighbors, |
279 | | - metric=self.metric, |
280 | | - ) |
| 303 | + knn = None |
| 304 | + if sklearn_check_version("1.1"): |
| 305 | + knn = NearestNeighbors( |
| 306 | + algorithm='auto', |
| 307 | + n_jobs=self.n_jobs, |
| 308 | + n_neighbors=n_neighbors, |
| 309 | + metric=self.metric, |
| 310 | + metric_params=self.metric_params |
| 311 | + ) |
| 312 | + else: |
| 313 | + knn = NearestNeighbors( |
| 314 | + algorithm='auto', |
| 315 | + n_jobs=self.n_jobs, |
| 316 | + n_neighbors=n_neighbors, |
| 317 | + metric=self.metric |
| 318 | + ) |
281 | 319 | t0 = time() |
282 | 320 | knn.fit(X) |
283 | 321 | duration = time() - t0 |
@@ -336,11 +374,13 @@ def _fit(self, X, skip_num_points=0): |
336 | 374 | # Laurens van der Maaten, 2009. |
337 | 375 | degrees_of_freedom = max(self.n_components - 1, 1) |
338 | 376 |
|
339 | | - daal_ready = self.method == 'barnes_hut' and self.n_components == 2 and \ |
340 | | - self.verbose == 0 and daal_check_version((2021, 'P', 600)) |
| 377 | + daal_ready = self.method == 'barnes_hut' and \ |
| 378 | + self.n_components == 2 and self.verbose == 0 and \ |
| 379 | + daal_check_version((2021, 'P', 600)) |
341 | 380 |
|
342 | 381 | if daal_ready: |
343 | | - X_embedded = check_array(X_embedded, dtype=[np.float32, np.float64]) |
| 382 | + X_embedded = check_array( |
| 383 | + X_embedded, dtype=[np.float32, np.float64]) |
344 | 384 | return self._daal_tsne( |
345 | 385 | P, |
346 | 386 | n_samples, |
|
0 commit comments