Skip to content

Commit 324ffac

Browse files
committed
refactor: rewrite native extensions of decision tree in C
1 parent fb9ccfc commit 324ffac

File tree

6 files changed

+514
-591
lines changed

6 files changed

+514
-591
lines changed

rumale-tree/ext/rumale/tree/ext.c

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#include "ext.h"
2+
3+
DEF_ITER_FIND_SPLIT_CLS(dfloat)
4+
DEF_ITER_FIND_SPLIT_CLS(sfloat)
5+
6+
static VALUE find_split_params_cls(
7+
VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
8+
VALUE n_classes
9+
) {
10+
VALUE klass = rb_obj_class(features);
11+
ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { klass, 1 }, { numo_cInt32, 1 } };
12+
size_t out_shape[1] = { 4 };
13+
ndfunc_arg_out_t aout[1] = { { klass, 1, out_shape } };
14+
na_iter_func_t iter_func = (na_iter_func_t)iter_dfloat_find_split_cls;
15+
if (klass == numo_cSFloat) {
16+
iter_func = (na_iter_func_t)iter_sfloat_find_split_cls;
17+
}
18+
ndfunc_t ndf = { iter_func, NO_LOOP, 3, 1, ain, aout };
19+
split_cls_opt opts = { StringValueCStr(criterion), NUM2SIZET(n_classes), NUM2DBL(impurity) };
20+
VALUE res = na_ndloop3(&ndf, &opts, 3, order, features, labels);
21+
RB_GC_GUARD(criterion);
22+
return res;
23+
}
24+
25+
DEF_ITER_NODE_IMPURITY_CLS(dfloat)
26+
DEF_ITER_NODE_IMPURITY_CLS(sfloat)
27+
28+
static VALUE
29+
node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes, VALUE klass) {
30+
ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
31+
ndfunc_arg_out_t aout[1] = { { klass, 0 } };
32+
na_iter_func_t iter_func = (na_iter_func_t)iter_dfloat_node_impurity_cls;
33+
if (klass == numo_cSFloat) {
34+
iter_func = (na_iter_func_t)iter_sfloat_node_impurity_cls;
35+
}
36+
ndfunc_t ndf = { iter_func, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
37+
node_impurity_cls_opt opts = { StringValueCStr(criterion), NUM2SIZET(n_classes) };
38+
VALUE res = na_ndloop3(&ndf, &opts, 1, y);
39+
RB_GC_GUARD(criterion);
40+
return res;
41+
}
42+
43+
/**
 * Loop iterator for check_same_label.
 *
 * Reads the int32 data of argument 0 (length taken from the first entry of its
 * shape) and stores Qtrue in the slot of argument 1 when every element equals
 * the first one, Qfalse otherwise. An empty or single-element input gives Qtrue.
 */
static void iter_check_same_label(na_loop_t const* lp) {
  const int32_t* labels = (int32_t*)NDL_PTR(lp, 0);
  const size_t n_labels = NDL_SHAPE(lp, 0)[0];
  VALUE* out = (VALUE*)NDL_PTR(lp, 1);

  /* Advance past the leading run of elements equal to labels[0]. */
  size_t pos = 1;
  while (pos < n_labels && labels[pos] == labels[0]) {
    pos++;
  }

  /* Stopping before the end means a differing label was found. */
  *out = (pos < n_labels) ? Qfalse : Qtrue;
}
58+
59+
/**
 * Ruby-visible wrapper, registered in Init_ext as the private method
 * ExtDecisionTreeClassifier#stop_growing?, that runs iter_check_same_label
 * over the label array.
 *
 * @param self  receiver; unused.
 * @param y     declared to the loop as Numo::Int32 with one user dimension.
 * @return the loop output, declared as a zero-dimensional Numo::RObject with
 *         NDF_EXTRACT; the iterator stores Qtrue or Qfalse in it.
 */
static VALUE check_same_label(VALUE self, VALUE y) {
  ndfunc_arg_in_t inputs[1] = { { numo_cInt32, 1 } };
  ndfunc_arg_out_t outputs[1] = { { numo_cRObject, 0 } };
  ndfunc_t loop_def = {
    (na_iter_func_t)iter_check_same_label, NO_LOOP | NDF_EXTRACT, 1, 1, inputs, outputs
  };
  VALUE all_same = na_ndloop(&loop_def, 1, y);
  return all_same;
}
67+
68+
DEF_ITER_FIND_SPLIT_REG(dfloat)
69+
DEF_ITER_FIND_SPLIT_REG(sfloat)
70+
71+
static VALUE find_split_params_reg(
72+
VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets
73+
) {
74+
VALUE klass = rb_obj_class(features);
75+
ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { klass, 1 }, { numo_cDFloat, 2 } };
76+
size_t out_shape[1] = { 4 };
77+
ndfunc_arg_out_t aout[1] = { { klass, 1, out_shape } };
78+
na_iter_func_t iter_func = (na_iter_func_t)iter_dfloat_find_split_reg;
79+
if (klass == numo_cSFloat) {
80+
iter_func = (na_iter_func_t)iter_sfloat_find_split_reg;
81+
}
82+
ndfunc_t ndf = { iter_func, NO_LOOP, 3, 1, ain, aout };
83+
split_reg_opt opts = { StringValueCStr(criterion), NUM2DBL(impurity) };
84+
VALUE res = na_ndloop3(&ndf, &opts, 3, order, features, targets);
85+
RB_GC_GUARD(criterion);
86+
return res;
87+
}
88+
89+
DEF_ITER_NODE_IMPURITY_REG(dfloat)
90+
DEF_ITER_NODE_IMPURITY_REG(sfloat)
91+
92+
static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y, VALUE klass) {
93+
ndfunc_arg_in_t ain[1] = { { klass, 2 } };
94+
ndfunc_arg_out_t aout[1] = { { klass, 0 } };
95+
na_iter_func_t iter_func = (na_iter_func_t)iter_dfloat_node_impurity_reg;
96+
if (klass == numo_cSFloat) {
97+
iter_func = (na_iter_func_t)iter_sfloat_node_impurity_reg;
98+
}
99+
ndfunc_t ndf = { iter_func, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
100+
node_impurity_reg_opt opts = { StringValueCStr(criterion) };
101+
VALUE res = na_ndloop3(&ndf, &opts, 1, y);
102+
RB_GC_GUARD(criterion);
103+
return res;
104+
}
105+
106+
/*
 * Generate iter_dfloat_check_same_value and iter_sfloat_check_same_value
 * (macro defined outside this file).
 * NOTE(review): the sfloat variant is also instantiated with DBL_EPSILON;
 * confirm whether FLT_EPSILON was intended for single precision.
 */
DEF_ITER_CHECK_SAME_VALUE(dfloat, DBL_EPSILON)
DEF_ITER_CHECK_SAME_VALUE(sfloat, DBL_EPSILON)

/**
 * Ruby-visible wrapper, registered in Init_ext as the private method
 * ExtDecisionTreeRegressor#stop_growing?, that runs the check-same-value
 * iterator over the target array.
 *
 * @param self  receiver; unused.
 * @param y     float array declared to the loop with two user dimensions; its
 *              class picks the iterator.
 * @return the loop output, declared as a zero-dimensional Numo::RObject with
 *         NDF_EXTRACT. Its value is written by the macro-generated iterator.
 */
static VALUE check_same_value(VALUE self, VALUE y) {
  /*
   * Only the sfloat and dfloat iterators exist. Anything that is not exactly
   * Numo::SFloat is declared as Numo::DFloat, so the input element class always
   * matches the iterator chosen below instead of pairing the dfloat iterator
   * with an arbitrary class (for example an integer array).
   */
  const int use_sfloat = (rb_obj_class(y) == numo_cSFloat);
  VALUE klass = use_sfloat ? numo_cSFloat : numo_cDFloat;
  na_iter_func_t iter_func = use_sfloat ? (na_iter_func_t)iter_sfloat_check_same_value
                                        : (na_iter_func_t)iter_dfloat_check_same_value;

  ndfunc_arg_in_t ain[1] = { { klass, 2 } };
  ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
  ndfunc_t ndf = { iter_func, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
  return na_ndloop(&ndf, 1, y);
}
120+
121+
DEF_ITER_FIND_SPLIT_GREG(dfloat)
122+
DEF_ITER_FIND_SPLIT_GREG(sfloat)
123+
124+
static VALUE find_split_params_greg(
125+
VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians, VALUE sum_gradient,
126+
VALUE sum_hessian, VALUE reg_lambda
127+
) {
128+
VALUE klass = rb_obj_class(features);
129+
ndfunc_arg_in_t ain[4] = { { numo_cInt32, 1 }, { klass, 1 }, { klass, 1 }, { klass, 1 } };
130+
size_t out_shape[1] = { 2 };
131+
ndfunc_arg_out_t aout[1] = { { klass, 1, out_shape } };
132+
na_iter_func_t iter_func = (na_iter_func_t)iter_dfloat_find_split_greg;
133+
if (klass == numo_cSFloat) {
134+
iter_func = (na_iter_func_t)iter_sfloat_find_split_greg;
135+
}
136+
ndfunc_t ndf = { iter_func, NO_LOOP, 4, 1, ain, aout };
137+
split_greg_opt opts = { NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda) };
138+
VALUE params = na_ndloop3(&ndf, &opts, 4, order, features, gradients, hessians);
139+
return params;
140+
}
141+
142+
void Init_ext(void) {
143+
VALUE rb_mRumale = rb_define_module("Rumale");
144+
VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
145+
VALUE rb_mExtDTreeCls = rb_define_module_under(rb_mTree, "ExtDecisionTreeClassifier");
146+
rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
147+
rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_cls, 4);
148+
rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label, 1);
149+
VALUE rb_mExtDTreeReg = rb_define_module_under(rb_mTree, "ExtDecisionTreeRegressor");
150+
rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
151+
rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_reg, 3);
152+
rb_define_private_method(rb_mExtDTreeReg, "stop_growing?", check_same_value, 1);
153+
VALUE rb_mExtGTreeReg = rb_define_module_under(rb_mTree, "ExtGradientTreeRegressor");
154+
rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_greg, 7);
155+
}

rumale-tree/ext/rumale/tree/ext.cpp

Lines changed: 0 additions & 39 deletions
This file was deleted.

0 commit comments

Comments
 (0)