Skip to content

Commit b6a2234

Browse files
authored
t-SNE patching (#778)
1 parent 2f48ace commit b6a2234

File tree

7 files changed

+116
-3
lines changed

7 files changed

+116
-3
lines changed

daal4py/sklearn/manifold/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
#===============================================================================
3-
# Copyright 2014 Intel Corporation
3+
# Copyright 2020 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.

daal4py/sklearn/manifold/_t_sne.py

100644100755
Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from time import time
2121
import numpy as np
2222
from scipy.sparse import issparse
23+
import daal4py
24+
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
2325

2426
from sklearn.manifold import TSNE as BaseTSNE
2527
from sklearn.decomposition import PCA
@@ -28,7 +30,6 @@
2830
from sklearn.utils import check_random_state, check_array
2931

3032
from ..neighbors import NearestNeighbors
31-
from .._utils import sklearn_check_version
3233
from .._device_offload import support_usm_ndarray
3334

3435
if sklearn_check_version('0.22'):
@@ -88,6 +89,48 @@ def fit(self, X, y=None):
8889
"""
8990
return super().fit(X, y)
9091

92+
def _daal_tsne(self, P, n_samples, X_embedded):
93+
"""Runs t-SNE."""
94+
# t-SNE minimizes the Kullback-Leiber divergence of the Gaussians P
95+
# and the Student's t-distributions Q. The optimization algorithm that
96+
# we use is batch gradient descent with two stages:
97+
# * initial optimization with early exaggeration and momentum at 0.5
98+
# * final optimization with momentum at 0.8
99+
100+
# N, nnz, n_iter_without_progress, n_iter
101+
size_iter = np.array([[n_samples], [P.nnz], [self.n_iter_without_progress],
102+
[self.n_iter]], dtype=P.dtype)
103+
params = np.array([[self.early_exaggeration], [self._learning_rate],
104+
[self.min_grad_norm], [self.angle]], dtype=P.dtype)
105+
results = np.zeros((3, 1), dtype=P.dtype) # curIter, error, gradNorm
106+
107+
if P.dtype == np.float64:
108+
daal4py.daal_tsne_gradient_descent(
109+
X_embedded,
110+
P,
111+
size_iter,
112+
params,
113+
results,
114+
0)
115+
elif P.dtype == np.float32:
116+
daal4py.daal_tsne_gradient_descent(
117+
X_embedded,
118+
P,
119+
size_iter,
120+
params,
121+
results,
122+
1)
123+
else:
124+
raise ValueError("unsupported dtype of 'P' matrix")
125+
126+
# Save the final number of iterations
127+
self.n_iter_ = int(results[0][0])
128+
129+
# Save Kullback-Leiber divergence
130+
self.kl_divergence_ = results[1][0]
131+
132+
return X_embedded
133+
91134
def _fit(self, X, skip_num_points=0):
92135
"""Private function to fit the model using X as training data."""
93136
if isinstance(self.init, str) and self.init == 'warn':
@@ -293,6 +336,16 @@ def _fit(self, X, skip_num_points=0):
293336
# Laurens van der Maaten, 2009.
294337
degrees_of_freedom = max(self.n_components - 1, 1)
295338

339+
daal_ready = self.method == 'barnes_hut' and self.n_components == 2 and \
340+
self.verbose == 0 and daal_check_version((2021, 'P', 600))
341+
342+
if daal_ready:
343+
X_embedded = check_array(X_embedded, dtype=[np.float32, np.float64])
344+
return self._daal_tsne(
345+
P,
346+
n_samples,
347+
X_embedded=X_embedded
348+
)
296349
return self._tsne(
297350
P,
298351
degrees_of_freedom,

deselected_tests.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ deselected_tests:
209209
# Some sklearnex docstrings differ from scikit-learn.
210210
- tests/test_docstrings.py >=1.0.2
211211

212+
# Accuracy of sklearnex and sklearn may differ due to different approaches
213+
- manifold/tests/test_t_sne.py::test_preserve_trustworthiness_approximately_with_precomputed_distances
214+
- manifold/tests/test_t_sne.py::test_bh_match_exact
215+
- manifold/tests/test_t_sne.py::test_uniform_grid[barnes_hut]
216+
- manifold/tests/test_t_sne.py::test_sparse_precomputed_distance
217+
- manifold/tests/test_t_sne.py::test_tsne_different_square_distances >=0.24
218+
212219
# Temporary deselected up to 2021.6 release. Need to fix
213220
- ensemble/tests/test_bagging.py::test_classification
214221

generator/wrapper_gen.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,20 @@ def daal_generate_shuffled_indices(idx, random_state):
260260
c_generate_shuffled_indices(data_or_file(<PyObject*>idx),
261261
data_or_file(<PyObject*>random_state))
262262
263+
264+
cdef extern from "daal4py.h":
265+
cdef void c_tsne_gradient_descent(data_or_file & init, data_or_file & p,
266+
data_or_file & size_iter, data_or_file & params,
267+
data_or_file & results, char dtype) except +
268+
269+
270+
def daal_tsne_gradient_descent(init, p, size_iter, params, results, dtype=0):
271+
c_tsne_gradient_descent(data_or_file(<PyObject*>init), data_or_file(<PyObject*>p),
272+
data_or_file(<PyObject*>size_iter),
273+
data_or_file(<PyObject*>params),
274+
data_or_file(<PyObject*>results), dtype)
275+
276+
263277
def _execute_with_context(func):
264278
def exec_func(*args, **keyArgs):
265279
if 'daal4py.oneapi' in sys.modules:

generator/wrappers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def wrap_algo(algo, ver):
4141
'algorithms::classification::training',
4242
'algorithms::tree_utils',
4343
'algorithms::tree_utils::classification',
44-
'algorithms::tree_utils::regression']):
44+
'algorithms::tree_utils::regression',
45+
'algorithms::internal']):
4546
return False
4647
# ignore unsupported algos
4748
if any(x in algo for x in ['quality_metric', '::interface']):

src/daal4py.cpp

100644100755
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,3 +897,32 @@ void c_generate_shuffled_indices(data_or_file & idx, data_or_file & random_state
897897
#else
898898
#endif
899899
}
900+
901+
void c_tsne_gradient_descent(data_or_file & init, data_or_file & p, data_or_file & size_iter, data_or_file & params, data_or_file & results, char dtype)
902+
{
903+
#if __INTEL_DAAL__ == 2021 && INTEL_DAAL_VERSION >= 20210600
904+
auto initTable = get_table(init);
905+
auto pTable = get_table(p);
906+
auto sizeIterTable = get_table(size_iter);
907+
auto paramTable = get_table(params);
908+
auto resultTable = get_table(results);
909+
daal::data_management::CSRNumericTablePtr csrTable = daal::services::dynamicPointerCast<daal::data_management::CSRNumericTable, daal::data_management::NumericTable>(pTable);
910+
911+
if (csrTable)
912+
{
913+
switch (dtype)
914+
{
915+
case 0:
916+
daal::algorithms::internal::tsneGradientDescent<int, double>(initTable, csrTable, sizeIterTable, paramTable, resultTable);
917+
break;
918+
case 1:
919+
daal::algorithms::internal::tsneGradientDescent<int, float>(initTable, csrTable, sizeIterTable, paramTable, resultTable);
920+
break;
921+
default: throw std::invalid_argument("Invalid data type specified.");
922+
}
923+
}
924+
else
925+
PyErr_SetString(PyExc_RuntimeError, "Unexpected table type");
926+
#else
927+
#endif
928+
}

src/daal4py.h

100644100755
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ using daal::services::LibraryVersionInfo;
5555
#if __INTEL_DAAL__ == 2021 && INTEL_DAAL_VERSION >= 20210200
5656
#include "data_management/data/internal/roc_auc_score.h"
5757
#endif
58+
#if __INTEL_DAAL__ == 2021 && INTEL_DAAL_VERSION >= 20210600
59+
#include "algorithms/tsne/tsne_gradient_descent.h"
60+
#endif
5861

5962

6063
extern "C" {
@@ -342,4 +345,10 @@ extern "C" {
342345
void c_generate_shuffled_indices(data_or_file & idx, data_or_file & random_state);
343346
}
344347

348+
extern "C"
349+
{
350+
void c_tsne_gradient_descent(data_or_file & init, data_or_file & p, data_or_file & size_iter,
351+
data_or_file & params, data_or_file & results, char dtype);
352+
}
353+
345354
#endif // _HLAPI_H_INCLUDED_

0 commit comments

Comments
 (0)