Skip to content

Commit 1674bb9

Browse files
author
KulikovNikita
authored
Basic statistics algorithm interface in oneDAL (#1205)
1 parent a55e393 commit 1674bb9

File tree

19 files changed

+648
-41
lines changed

19 files changed

+648
-41
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# ===============================================================================
2+
# Copyright 2023 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ===============================================================================
16+
17+
import numpy as np
18+
from mpi4py import MPI
19+
20+
from dpctl import SyclQueue
21+
from sklearnex.spmd.basic_statistics import BasicStatistics as BasicStatisticsSpmd
22+
23+
24+
def generate_data(par, size, seed=777):
    """Build a deterministic synthetic dataset made of ``size`` stacked blocks.

    Block ``b`` (0-based) holds ``ns`` rows drawn uniformly from
    ``[b, (b + 1) ** 2)`` and ``ns`` weights drawn uniformly from
    ``[1, b + 1)``; for ``b == 0`` both weight bounds are 1.

    Parameters
    ----------
    par : mapping
        Must provide ``'ns'`` (rows per block) and ``'nf'`` (columns).
    size : int
        Number of blocks to generate.
    seed : int, default=777
        Seed passed to ``numpy.random.default_rng``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data, weights)`` with shapes ``(size * ns, nf)`` and
        ``(size * ns,)``.

    Raises
    ------
    ValueError
        If ``size`` is smaller than 1 (there would be nothing to concatenate).
    KeyError
        If ``par`` lacks ``'ns'`` or ``'nf'``.
    """
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size}")

    ns, nf = par['ns'], par['nf']
    rng = np.random.default_rng(seed)

    data_blocks, weight_blocks = [], []
    for b in range(size):
        # Draw the data before the weights: the order of the two draws
        # determines which values of the random stream each array receives.
        data_blocks.append(rng.uniform(b, (b + 1) * (b + 1), size=(ns, nf)))
        weight_blocks.append(rng.uniform(1, (b + 1), size=ns))

    data = np.concatenate(data_blocks, axis=0)
    weights = np.concatenate(weight_blocks)

    return (data, weights)
41+
42+
43+
# SYCL queue created from the "gpu" filter string; handed to compute() below.
q = SyclQueue("gpu")

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

params_spmd = {'ns': 19, 'nf': 31}

# One block of 'ns' rows is generated per process in the communicator.
data, weights = generate_data(params_spmd, size)

# Scale every row by its weight through broadcasting. For finite inputs this
# yields the same values as ``np.diag(weights) @ data`` without allocating
# the dense (n_rows, n_rows) diagonal matrix.
weighted_data = weights[:, np.newaxis] * data

# NumPy reference values. This script computes them but neither prints nor
# compares them against the SPMD result.
gtr_mean = np.mean(weighted_data, axis=0)
gtr_std = np.std(weighted_data, axis=0)

bss = BasicStatisticsSpmd(["mean", "standard_deviation"])
res = bss.compute(data, weights, queue=q)

print(f"Computed mean on rank {rank}:\n", res["mean"])
print(f"Computed std on rank {rank}:\n", res["standard_deviation"])
Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#===============================================================================
1+
# ===============================================================================
22
# Copyright 2023 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,54 +12,57 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
#===============================================================================
15+
# ===============================================================================
1616

1717
import numpy as np
18+
from warnings import warn
19+
1820
from mpi4py import MPI
19-
import dpctl
20-
from numpy.testing import assert_allclose
21-
from onedal.spmd.linear_model import LinearRegression as LinRegSpmd
21+
from dpctl import SyclQueue
22+
from sklearnex.spmd.linear_model import LinearRegression
2223

2324

24-
def generate_X_y(par, coef_seed, data_seed):
25-
ns, nf, nr = par['ns'], par['nf'], par['nr']
25+
def generate_X_y(ns, data_seed):
26+
nf, nr = 129, 131
2627

27-
crng = np.random.default_rng(coef_seed)
28+
crng = np.random.default_rng(777)
2829
coef = crng.uniform(-4, 1, size=(nr, nf)).T
2930
intp = crng.uniform(-1, 9, size=(nr, ))
3031

3132
drng = np.random.default_rng(data_seed)
3233
data = drng.uniform(-7, 7, size=(ns, nf))
3334
resp = data @ coef + intp[np.newaxis, :]
3435

35-
return data, resp, coef, intp
36+
return data, resp
37+
38+
39+
def get_train_data(rank):
40+
return generate_X_y(101, rank)
41+
3642

43+
def get_test_data(rank):
44+
return generate_X_y(1024, rank)
3745

38-
if __name__ == "__main__":
39-
q = dpctl.SyclQueue("gpu")
4046

41-
comm = MPI.COMM_WORLD
42-
rank = comm.Get_rank()
43-
size = comm.Get_size()
47+
comm = MPI.COMM_WORLD
48+
rank = comm.Get_rank()
49+
size = comm.Get_size()
4450

45-
params_spmd = {'ns': 15, 'nf': 21, 'nr': 23}
46-
params_grtr = {'ns': 77, 'nf': 21, 'nr': 23}
51+
if size < 2:
52+
warn("This example was intentionally "
53+
"designed to run in distributed mode only", RuntimeWarning)
4754

48-
Xsp, ysp, csp, isp = generate_X_y(params_spmd, size, size + rank - 1)
49-
Xgt, ygt, cgt, igt = generate_X_y(params_grtr, size, size + rank + 1)
55+
X, y = get_train_data(rank)
5056

51-
assert_allclose(csp, cgt)
52-
assert_allclose(isp, igt)
57+
queue = SyclQueue("gpu")
5358

54-
lrsp = LinRegSpmd(copy_X=True, fit_intercept=True)
55-
lrsp.fit(Xsp, ysp, queue=q)
59+
model = LinearRegression().fit(X, y, queue)
5660

57-
assert_allclose(lrsp.coef_, csp.T)
58-
assert_allclose(lrsp.intercept_, isp)
61+
print(f"Coefficients on rank {rank}:\n", model.coef_)
62+
print(f"Intercept on rank {rank}:\n", model.intercept_)
5963

60-
ypr = lrsp.predict(Xgt, queue=q)
64+
X_test, _ = get_test_data(rank)
6165

62-
assert_allclose(ypr, ygt)
66+
result = model.predict(X_test, queue)
6367

64-
print("Groundtruth responses:\n", ygt)
65-
print("Computed responses:\n", ypr)
68+
print(f"Result on rank {rank}:\n", result)

onedal/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
__all__.append('spmd')
4747

4848
if daal_check_version((2023, 'P', 100)):
49-
__all__.append('linear_model')
49+
__all__ += ['basic_statistics', 'linear_model']
5050

5151
if _is_dpc_backend:
52-
__all__.append('spmd.linear_model')
52+
__all__ += ['spmd.basic_statistics', 'spmd.linear_model']
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#===============================================================================
2+
# Copyright 2023 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#===============================================================================
16+
17+
from .basic_statistics import BasicStatistics
18+
19+
__all__ = ['BasicStatistics']
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "oneapi/dal/algo/basic_statistics.hpp"
18+
19+
#include "onedal/common.hpp"
20+
#include "onedal/version.hpp"
21+
22+
#include <string>
23+
#include <regex>
24+
#include <map>
25+
26+
namespace py = pybind11;
27+
28+
namespace oneapi::dal::python {
29+
30+
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100
31+
32+
namespace basic_statistics {
33+
34+
// Dispatch functor: reads the string stored under params["method"] and hands
// control to the ONEDAL_PARAM_DISPATCH_* macros together with `ops`, `Float`
// and the matching dal::basic_statistics::method tag.
// NOTE(review): the macros are defined outside this file; presumably each
// DISPATCH_VALUE invokes `ops` for its tag when the string matches -- confirm.
template <typename Task, typename Ops>
35+
struct method2t {
36+
// Only `ops` is stored. `task` is not used in the body; presumably it exists
// so that Task can be deduced at the call site (method2t{ Task{}, ops }).
method2t(const Task& task, const Ops& ops) : ops(ops) {}
37+
38+
// Recognised method strings are "dense" and "by_default"; any other value
// falls through to ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method).
template <typename Float>
39+
auto operator()(const py::dict& params) {
40+
using namespace dal::basic_statistics;
41+
42+
const auto method = params["method"].cast<std::string>();
43+
ONEDAL_PARAM_DISPATCH_VALUE(method, "dense", ops, Float, method::dense);
44+
ONEDAL_PARAM_DISPATCH_VALUE(method, "by_default", ops, Float, method::by_default);
45+
ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method);
46+
}
47+
48+
// Copy of the operation object given to the constructor.
Ops ops;
49+
};
50+
51+
// RESULT_OPTION(x) expands to the pair { "x", dal::basic_statistics::result_options::x },
// so each lookup key is spelled exactly like the oneDAL option it selects.
#define RESULT_OPTION(option) { #option, dal::basic_statistics::result_options::option }
52+
53+
// Name -> result_option_id table; get_onedal_result_options() looks every
// requested option name up here.
const std::map<std::string, dal::basic_statistics::result_option_id> result_option_registry {
54+
RESULT_OPTION(min), RESULT_OPTION(max), RESULT_OPTION(sum), RESULT_OPTION(mean),
55+
RESULT_OPTION(variance), RESULT_OPTION(variation), RESULT_OPTION(sum_squares),
56+
RESULT_OPTION(standard_deviation), RESULT_OPTION(sum_squares_centered),
57+
RESULT_OPTION(second_order_raw_moment)
58+
};
59+
60+
// The helper macro is only needed to build the table above.
#undef RESULT_OPTION
61+
62+
/// Translates the string stored under params["result_option"] into a combined
/// dal::basic_statistics::result_option_id.
///
/// Every maximal run of word characters in the string is treated as one option
/// name, looked up in result_option_registry and OR-ed into the result; the
/// characters separating the names are ignored.
///
/// An unknown name, or a std::regex_error raised while scanning, is reported
/// through ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(result_option).
/// If the string contains no names at all, the default-constructed
/// result_option_id is returned unchanged.
auto get_onedal_result_options(const py::dict& params) {
    using namespace dal::basic_statistics;

    const auto result_option = params["result_option"].cast<std::string>();
    result_option_id onedal_options;

    try {
        std::regex re("\\w+");
        const std::sregex_iterator last{};
        const std::sregex_iterator first( //
            result_option.begin(),
            result_option.end(),
            re);

        for (std::sregex_iterator it = first; it != last; ++it) {
            const auto str = it->str();
            const auto match = result_option_registry.find(str);
            if (match == result_option_registry.cend()) {
                ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(result_option);
            }
            else {
                onedal_options = onedal_options | match->second;
            }
        }
    }
    // Caught by const reference and left unnamed: the exception object itself
    // is not inspected, only converted into the "invalid value" error.
    catch (const std::regex_error&) {
        ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(result_option);
    }

    return onedal_options;
}
92+
93+
struct params2desc {
94+
template <typename Float, typename Method, typename Task>
95+
auto operator()(const py::dict& params) {
96+
auto desc = dal::basic_statistics::descriptor<Float,
97+
dal::basic_statistics::method::dense, dal::basic_statistics::task::compute>()
98+
.set_result_options(get_onedal_result_options(params));
99+
return desc;
100+
}
101+
};
102+
103+
// Primary template: intentionally empty, so only tasks with an explicit
// specialization below provide an operator().
template <typename Policy, typename Task>
104+
struct init_compute_ops_dispatcher {};
105+
106+
// Specialization for dal::basic_statistics::task::compute: registers the
// Python-callable "train" function on the given module.
template <typename Policy>
107+
struct init_compute_ops_dispatcher<Policy, dal::basic_statistics::task::compute> {
108+
void operator()(py::module_& m) {
109+
using Task = dal::basic_statistics::task::compute;
110+
// The bound callable takes (policy, params, data, weights).
m.def("train",
111+
[](const Policy& policy,
112+
const py::dict& params,
113+
const table& data,
114+
const table& weights) {
115+
using namespace dal::basic_statistics;
116+
using input_t = compute_input<Task>;
117+
118+
// Bundle policy, input tables and the descriptor factory, then let
// fptype2t / method2t dispatch on the contents of `params`.
// NOTE(review): compute_ops and fptype2t are declared outside this file;
// presumably fptype2t selects the floating-point type -- confirm.
compute_ops ops(policy, input_t{ data, weights }, params2desc{});
119+
return fptype2t{ method2t{ Task{}, ops } }(params);
120+
});
121+
}
122+
};
123+
124+
template <typename Policy, typename Task>
125+
void init_compute_ops(py::module& m) {
126+
init_compute_ops_dispatcher<Policy, Task>{}(m);
127+
}
128+
129+
template <typename Task>
130+
void init_compute_result(py::module_& m) {
131+
using namespace dal::basic_statistics;
132+
using result_t = compute_result<Task>;
133+
134+
auto cls = py::class_<result_t>(m, "compute_result")
135+
.def(py::init())
136+
.DEF_ONEDAL_PY_PROPERTY(min, result_t)
137+
.DEF_ONEDAL_PY_PROPERTY(max, result_t)
138+
.DEF_ONEDAL_PY_PROPERTY(sum, result_t)
139+
.DEF_ONEDAL_PY_PROPERTY(mean, result_t)
140+
.DEF_ONEDAL_PY_PROPERTY(variance, result_t)
141+
.DEF_ONEDAL_PY_PROPERTY(variation, result_t)
142+
.DEF_ONEDAL_PY_PROPERTY(sum_squares, result_t)
143+
.DEF_ONEDAL_PY_PROPERTY(standard_deviation, result_t)
144+
.DEF_ONEDAL_PY_PROPERTY(sum_squares_centered, result_t)
145+
.DEF_ONEDAL_PY_PROPERTY(second_order_raw_moment, result_t);
146+
}
147+
148+
ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_result);
149+
ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_ops);
150+
151+
} // namespace basic_statistics
152+
153+
// Module initializer: creates the "basic_statistics" submodule and registers
// the compute entry point and the result class on it.
// NOTE(review): `m` is not declared in this file; presumably the
// ONEDAL_PY_INIT_MODULE macro introduces it as the parent module -- confirm.
ONEDAL_PY_INIT_MODULE(basic_statistics) {
154+
using namespace dal::detail;
155+
using namespace basic_statistics;
156+
using namespace dal::basic_statistics;
157+
158+
auto sub = m.def_submodule("basic_statistics");
159+
// The compute task is the only task instantiated for this algorithm here.
using task_list = types<dal::basic_statistics::task::compute>;
160+
161+
// SPMD builds instantiate the compute ops over policy_list_spmd, all other
// builds over policy_list; the result class is registered in both cases.
#ifdef ONEDAL_DATA_PARALLEL_SPMD
162+
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list_spmd, task_list);
163+
#else // ONEDAL_DATA_PARALLEL_SPMD
164+
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task_list);
165+
#endif // ONEDAL_DATA_PARALLEL_SPMD
166+
167+
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task_list);
168+
}
169+
170+
ONEDAL_PY_TYPE2STR(dal::basic_statistics::task::compute, "compute");
171+
172+
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100
173+
174+
} // namespace oneapi::dal::python

0 commit comments

Comments
 (0)