7
7
8
8
9
9
@njit
def primal(alpha, y, X, w, weights):
    """Return the weighted Lasso primal objective.

    The value is
    ``||y - X @ w||^2 / (2 * len(y)) + alpha * sum_j |w[j] * weights[j]|``.

    Parameters
    ----------
    alpha : float
        Regularization strength multiplying the penalty.
    y : array
        Target vector.
    X : array
        Design matrix; ``X @ w`` must be conformable with ``y``.
    w : array
        Coefficient vector.
    weights : array
        Per-coefficient penalty weights, indexed like ``w``. Entries may be
        ``np.inf``.

    Returns
    -------
    float
        The objective value.
    """
    r = y - X @ w
    p_obj = (r @ r) / (2 * len(y))
    # Sum the penalty over non-zero coefficients only: a zero coefficient
    # contributes nothing, and multiplying it by an infinite weight would
    # give 0 * inf = nan and poison every later convergence comparison.
    penalty = 0.
    for j in range(len(w)):
        if w[j] != 0.:
            penalty += abs(w[j] * weights[j])
    return p_obj + alpha * penalty
14
14
15
15
16
16
@njit
def primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights):
    """Return the weighted group Lasso primal objective.

    The value is ``||y - X @ w||^2 / (2 * len(y))`` plus, for every group
    ``g``, ``alpha * ||w_g * weights[g]||_2``, where ``w_g`` gathers the
    coefficients ``w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]``.

    Parameters
    ----------
    alpha : float
        Regularization strength multiplying the penalty.
    y : array
        Target vector.
    X : array
        Design matrix; ``X @ w`` must be conformable with ``y``.
    w : array
        Coefficient vector.
    grp_ptr : array of int
        Group boundaries into ``grp_indices``; there are
        ``len(grp_ptr) - 1`` groups.
    grp_indices : array of int
        Feature indices, concatenated group after group.
    weights : array
        Per-group penalty weights. Entries may be ``np.inf``.

    Returns
    -------
    float
        The objective value.
    """
    r = y - X @ w
    p_obj = (r @ r) / (2 * len(y))
    for g in range(len(grp_ptr) - 1):
        w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
        # An all-zero group adds nothing to the penalty. Skipping it also
        # avoids 0 * inf = nan when the group's weight is infinite.
        if np.any(w_g != 0.):
            p_obj += alpha * norm(w_g * weights[g], ord=2)
    return p_obj
24
24
25
25
26
- def gram_lasso (X , y , alpha , max_iter , tol , check_freq = 10 ):
26
+ def gram_lasso (X , y , alpha , max_iter , tol , w_init = None , weights = None , check_freq = 10 ):
27
27
p_obj_prev = np .inf
28
28
n_features = X .shape [1 ]
29
29
grads = X .T @ y / len (y )
30
30
G = X .T @ X
31
31
lipschitz = np .zeros (n_features , dtype = X .dtype )
32
32
for j in range (n_features ):
33
33
lipschitz [j ] = (X [:, j ] ** 2 ).sum () / len (y )
34
- w = np .zeros (n_features )
34
+ w = w_init if w_init is not None else np .zeros (n_features )
35
+ weights = weights if weights is not None else np .ones (n_features )
35
36
# CD
36
37
for n_iter in range (max_iter ):
37
- cd_epoch (X , G , grads , w , alpha , lipschitz )
38
+ cd_epoch (X , G , grads , w , alpha , lipschitz , weights )
38
39
if n_iter % check_freq == 0 :
39
- p_obj = primal (alpha , y , X , w )
40
+ p_obj = primal (alpha , y , X , w , weights )
40
41
if p_obj_prev - p_obj < tol :
41
42
print ("Convergence reached!" )
42
43
break
@@ -45,7 +46,8 @@ def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10):
45
46
return w
46
47
47
48
48
- def gram_group_lasso (X , y , alpha , groups , max_iter , tol , check_freq = 50 ):
49
+ def gram_group_lasso (X , y , alpha , groups , max_iter , tol , w_init = None , weights = None ,
50
+ check_freq = 50 ):
49
51
p_obj_prev = np .inf
50
52
n_features = X .shape [1 ]
51
53
grp_ptr , grp_indices = _grp_converter (groups , X .shape [1 ])
@@ -56,12 +58,13 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
56
58
for g in range (n_groups ):
57
59
X_g = X [:, grp_indices [grp_ptr [g ]:grp_ptr [g + 1 ]]]
58
60
lipschitz [g ] = norm (X_g , ord = 2 ) ** 2 / len (y )
59
- w = np .zeros (n_features )
61
+ w = w_init if w_init is not None else np .zeros (n_features )
62
+ weights = weights if weights is not None else np .ones (n_groups )
60
63
# BCD
61
64
for n_iter in range (max_iter ):
62
- bcd_epoch (X , G , grads , w , alpha , lipschitz , grp_indices , grp_ptr )
65
+ bcd_epoch (X , G , grads , w , alpha , lipschitz , grp_indices , grp_ptr , weights )
63
66
if n_iter % check_freq == 0 :
64
- p_obj = primal_grp (alpha , y , X , w , grp_ptr , grp_indices )
67
+ p_obj = primal_grp (alpha , y , X , w , grp_ptr , grp_indices , weights )
65
68
if p_obj_prev - p_obj < tol :
66
69
print ("Convergence reached!" )
67
70
break
@@ -71,26 +74,27 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
71
74
72
75
73
76
@njit
def cd_epoch(X, G, grads, w, alpha, lipschitz, weights):
    """Run one in-place coordinate descent pass over all features.

    ``w`` and ``grads`` are modified in place; nothing is returned.

    Parameters
    ----------
    X : array
        Design matrix; only its shape is used here.
    G : array
        Gram matrix (``X.T @ X`` as built by ``gram_lasso``).
    grads : array
        Running correlation vector, updated whenever a coefficient changes.
    w : array
        Coefficient vector.
    alpha : float
        Regularization strength.
    lipschitz : array
        Per-feature step-size constants. A zero entry marks a feature that
        is left untouched.
    weights : array
        Per-feature penalty weights. ``np.inf`` forces the coefficient to 0.
    """
    n_samples, n_features = X.shape
    for j in range(n_features):
        old_w_j = w[j]
        if weights[j] == np.inf:
            # An infinite penalty admits only w[j] == 0. Enforce it so a
            # non-zero warm start cannot leave the objective infinite.
            w[j] = 0.
        elif lipschitz[j] == 0.:
            continue
        else:
            # ST is defined elsewhere in this module (presumably the
            # soft-thresholding operator -- confirm).
            w[j] = ST(w[j] + grads[j] / lipschitz[j],
                      alpha / lipschitz[j] * weights[j])
        if old_w_j != w[j]:
            grads += G[j, :] * (old_w_j - w[j]) / n_samples
83
86
84
87
85
88
@njit
def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights):
    """Run one in-place block coordinate descent pass over all groups.

    ``w`` and ``grads`` are modified in place; nothing is returned.

    Parameters
    ----------
    X : array
        Design matrix; only ``len(X)`` is used here.
    G : array
        Gram matrix, indexed by feature.
    grads : array
        Running correlation vector, updated whenever a group changes.
    w : array
        Coefficient vector.
    alpha : float
        Regularization strength.
    lipschitz : array
        Per-group step-size constants. A zero entry marks a group that is
        left untouched.
    grp_indices : array of int
        Feature indices, concatenated group after group.
    grp_ptr : array of int
        Group boundaries into ``grp_indices``; there are
        ``len(grp_ptr) - 1`` groups.
    weights : array
        Per-group penalty weights. ``np.inf`` forces the whole group to 0.
    """
    n_groups = len(grp_ptr) - 1
    for g in range(n_groups):
        inf_weight = weights[g] == np.inf
        # Either condition alone must prevent the BST step: a zero Lipschitz
        # constant would divide by zero, and an infinite weight would pass
        # an infinite threshold to BST.
        if lipschitz[g] == 0. and not inf_weight:
            continue
        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
        old_w_g = w[idx].copy()
        if inf_weight:
            # An infinite penalty admits only a zero group. Enforce it so a
            # non-zero warm start cannot leave the objective infinite.
            w[idx] = np.zeros_like(old_w_g)
        else:
            # BST is defined elsewhere in this module (presumably the block
            # soft-thresholding operator -- confirm).
            w[idx] = BST(w[idx] + grads[idx] / lipschitz[g],
                         alpha / lipschitz[g] * weights[g])
        diff = old_w_g - w[idx]
        if np.any(diff != 0.):
            grads += diff @ G[idx, :] / len(X)
0 commit comments