Skip to content

Commit 777cf60

Browse files
committed
import other modules
1 parent 241b789 commit 777cf60

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

rehline/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
# Import from internal C++ module
22
from ._internal import rehline_internal, rehline_result
33

4+
from .rehloss import ReHLoss, PQLoss
5+
from .main import ReHLine, ReHLine_solver
6+
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
""" ReHLine: Regularized Composite ReHU/ReLU Loss Minimization """
1+
""" ReHLine: Regularized Composite ReLU-ReHU Loss Minimization with Linear Computation and Linear Convergence """
22

33
# Authors: Ben Dai <[email protected]>
4-
# C++ support by Yixuan Qiu <[email protected]>
4+
# C++ support by Yixuan Qiu <[email protected]>
55

66
# License: MIT License
77

88
import numpy as np
99
from sklearn.base import BaseEstimator
10-
import rehline
1110
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
12-
import base
11+
from .base import relu, rehu
12+
from ._internal import rehline_internal, rehline_result
1313

1414
def ReHLine_solver(X, U, V,
1515
Tau=np.empty(shape=(0, 0)),
1616
S=np.empty(shape=(0, 0)), T=np.empty(shape=(0, 0)),
1717
A=np.empty(shape=(0, 0)), b=np.empty(shape=(0)),
1818
max_iter=1000, tol=1e-4, shrink=True, verbose=True):
19-
result = rehline.rehline_result()
20-
rehline.rehline_internal(result, X, A, b, U, V, S, T, Tau, max_iter, tol, shrink, verbose)
19+
result = rehline_result()
20+
rehline_internal(result, X, A, b, U, V, S, T, Tau, max_iter, tol, shrink, verbose)
2121
return result
2222

2323
class ReHLine(BaseEstimator):
@@ -193,7 +193,7 @@ def call_ReLHLoss(self, input):
193193
relu_input = (self.U.T * input[:,np.newaxis]).T + self.V
194194
if self.H > 0:
195195
rehu_input = (self.S.T * input[:,np.newaxis]).T + self.T
196-
return np.sum(base.relu(relu_input), 0) + np.sum(base.rehu(rehu_input), 0)
196+
return np.sum(relu(relu_input), 0) + np.sum(rehu(rehu_input), 0)
197197

198198

199199
def fit(self, X, sample_weight=None):

rehline/rehloss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
""" ReHLoss: Convert a piecewise quadratic loss function to a ReHLoss. """
22

33
# Authors: Ben Dai <[email protected]>
4-
# Yixuan Qiu <[email protected]>
4+
# Yixuan Qiu <[email protected]>
55

66
# License: MIT License
77

88

99
import numpy as np
10-
import base
10+
from .base import relu, rehu, _check_relu
1111

1212
class ReHLoss(object):
1313
"""
@@ -68,14 +68,14 @@ def __call__(self, x):
6868
if (self.L > 0) and (self.H > 0):
6969
assert self.relu_coef.shape[1] == self.rehu_coef.shape[1], "n_samples for `relu_coef` and `rehu_coef` should be the same shape!"
7070

71-
base._check_relu(self.relu_coef, self.relu_intercept)
72-
base._check_rehu(self.rehu_coef, self.rehu_intercept, self.rehu_cut)
71+
_check_relu(self.relu_coef, self.relu_intercept)
72+
_check_rehu(self.rehu_coef, self.rehu_intercept, self.rehu_cut)
7373

7474
self.L, self.H, self.n = self.relu_coef.shape[0], self.rehu_coef.shape[0], self.relu_coef.shape[1]
7575
relu_input = (self.relu_coef.T * x[:,np.newaxis]).T + self.relu_intercept
7676
rehu_input = (self.rehu_coef.T * x[:,np.newaxis]).T + self.rehu_intercept
7777

78-
return np.sum(base.relu(relu_input), 0) + np.sum(base.rehu(rehu_input), 0)
78+
return np.sum(relu(relu_input), 0) + np.sum(rehu(rehu_input), 0)
7979

8080

8181
class PQLoss(object):

0 commit comments

Comments (0)