Skip to content

Commit 5696d40

Browse files
committed
Update estimators to use barebones solver, make it at least as fast as sklearn
1 parent c199221 commit 5696d40

File tree

3 files changed

+166
-78
lines changed

3 files changed

+166
-78
lines changed

skglm/estimators.py

Lines changed: 117 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import warnings
44
import numpy as np
5+
from scipy.linalg import pinvh
56
from scipy.sparse import issparse
67
from scipy.special import expit
78
from numbers import Integral, Real
@@ -25,6 +26,10 @@
2526
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
2627
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
2728
from skglm.utils.data import grp_converter
29+
from skglm.utils.prox_funcs import ST_vec
30+
31+
from numba import njit
32+
from skglm.solvers.gram_cd import barebones_cd_gram
2833

2934

3035
def _glm_fit(X, y, model, datafit, penalty, solver):
@@ -1687,22 +1692,32 @@ def fit(self, X, y):
16871692

16881693
return _glm_fit(X, y, self, quad_group, group_penalty, solver)
16891694

1695+
####################
1696+
# WIP Graphical Lasso
1697+
####################
1698+
16901699

16911700
class GraphicalLasso():
1701+
""" A first-order BCD Graphical Lasso solver implementing the GLasso algorithm
1702+
described in Friedman et al., 2008 and the P-GLasso algorithm described in
1703+
Mazumder et al., 2012."""
1704+
16921705
def __init__(self,
16931706
alpha=1.,
16941707
weights=None,
1695-
algo="banerjee",
1696-
max_iter=1000,
1708+
algo="dual",
1709+
max_iter=100,
16971710
tol=1e-8,
16981711
warm_start=False,
1712+
inner_tol=1e-4,
16991713
):
17001714
self.alpha = alpha
17011715
self.weights = weights
17021716
self.algo = algo
17031717
self.max_iter = max_iter
17041718
self.tol = tol
17051719
self.warm_start = warm_start
1720+
self.inner_tol = inner_tol
17061721

17071722
def fit(self, S):
17081723
p = S.shape[-1]
@@ -1716,90 +1731,110 @@ def fit(self, S):
17161731
raise ValueError("Weights should be symmetric.")
17171732

17181733
if self.warm_start and hasattr(self, "precision_"):
1719-
if self.algo == "banerjee":
1734+
if self.algo == "dual":
17201735
raise ValueError(
1721-
"Banerjee does not support warm start for now.")
1736+
"dual does not support warm start for now.")
17221737
Theta = self.precision_
17231738
W = self.covariance_
1724-
else:
1725-
W = S.copy() # + alpha*np.eye(p)
1726-
Theta = np.linalg.pinv(W, hermitian=True)
17271739

1728-
datafit = compiled_clone(QuadraticHessian())
1729-
penalty = compiled_clone(
1730-
WeightedL1(alpha=self.alpha, weights=Weights[0, :-1]))
1740+
else:
1741+
W = S.copy()
1742+
W *= 0.95
1743+
diagonal = S.flat[:: p + 1]
1744+
W.flat[:: p + 1] = diagonal
1745+
Theta = pinvh(W)
17311746

1732-
solver = AndersonCD(warm_start=True,
1733-
fit_intercept=False,
1734-
ws_strategy="fixpoint")
1747+
W_11 = np.copy(W[1:, 1:], order="C")
1748+
eps = np.finfo(np.float64).eps
1749+
it = 0
1750+
Theta_old = Theta.copy()
17351751

17361752
for it in range(self.max_iter):
17371753
Theta_old = Theta.copy()
1754+
17381755
for col in range(p):
1739-
indices_minus_col = np.concatenate(
1740-
[indices[:col], indices[col + 1:]])
1741-
_11 = indices_minus_col[:, None], indices_minus_col[None]
1742-
_12 = indices_minus_col, col
1743-
_21 = col, indices_minus_col
1744-
_22 = col, col
1745-
1746-
W_11 = W[_11]
1747-
w_12 = W[_12]
1748-
w_22 = W[_22]
1749-
s_12 = S[_12]
1750-
s_22 = S[_22]
1751-
1752-
penalty.weights = Weights[_12]
1753-
1754-
if self.algo == "banerjee":
1755-
w_init = Theta[_12]/Theta[_22]
1756-
Xw_init = W_11 @ w_init
1756+
if self.algo == "primal":
1757+
indices_minus_col = np.concatenate(
1758+
[indices[:col], indices[col + 1:]])
1759+
_11 = indices_minus_col[:, None], indices_minus_col[None]
1760+
_12 = indices_minus_col, col
1761+
_21 = col, indices_minus_col
1762+
_22 = col, col
1763+
1764+
elif self.algo == "dual":
1765+
if col > 0:
1766+
di = col - 1
1767+
W_11[di] = W[di][indices != col]
1768+
W_11[:, di] = W[:, di][indices != col]
1769+
else:
1770+
W_11[:] = W[1:, 1:]
1771+
1772+
s_12 = S[col, indices != col]
1773+
1774+
if self.algo == "dual":
1775+
beta_init = (Theta[indices != col, col] /
1776+
(Theta[col, col] + 1000 * eps))
17571777
Q = W_11
1758-
elif self.algo == "mazumder":
1759-
inv_Theta_11 = W_11 - np.outer(w_12, w_12)/w_22
1778+
1779+
elif self.algo == "primal":
1780+
inv_Theta_11 = (W[_11] -
1781+
np.outer(W[_12],
1782+
W[_12])/W[_22])
17601783
Q = inv_Theta_11
1761-
w_init = Theta[_12] * w_22
1762-
Xw_init = inv_Theta_11 @ w_init
1784+
beta_init = Theta[indices != col, col] * S[col, col]
17631785
else:
17641786
raise ValueError(f"Unsupported algo {self.algo}")
17651787

1766-
beta, _, _ = solver._solve(
1788+
beta = barebones_cd_gram(
17671789
Q,
17681790
s_12,
1769-
datafit,
1770-
penalty,
1771-
w_init=w_init,
1772-
Xw_init=Xw_init,
1791+
x=beta_init,
1792+
alpha=self.alpha,
1793+
weights=Weights[indices != col, col],
1794+
tol=self.inner_tol,
1795+
max_iter=self.max_iter,
17731796
)
17741797

1775-
if self.algo == "banerjee":
1776-
w_12 = -W_11 @ beta
1777-
W[_12] = w_12
1778-
W[_21] = w_12
1779-
Theta[_22] = 1/(s_22 + beta @ w_12)
1780-
Theta[_12] = beta*Theta[_22]
1781-
else: # mazumder
1782-
theta_12 = beta / s_22
1783-
theta_22 = 1/s_22 + theta_12 @ inv_Theta_11 @ theta_12
1784-
1785-
Theta[_12] = theta_12
1786-
Theta[_21] = theta_12
1787-
Theta[_22] = theta_22
1788-
1789-
w_22 = 1/(theta_22 - theta_12 @ inv_Theta_11 @ theta_12)
1790-
w_12 = -w_22*inv_Theta_11 @ theta_12
1791-
W_11 = inv_Theta_11 + np.outer(w_12, w_12)/w_22
1792-
W[_11] = W_11
1793-
W[_12] = w_12
1794-
W[_21] = w_12
1795-
W[_22] = w_22
1798+
if self.algo == "dual":
1799+
w_12 = -np.dot(W_11, beta)
1800+
W[col, indices != col] = w_12
1801+
W[indices != col, col] = w_12
1802+
1803+
Theta[col, col] = 1 / \
1804+
(W[col, col] + np.dot(beta, w_12))
1805+
Theta[indices != col, col] = beta*Theta[col, col]
1806+
Theta[col, indices != col] = beta*Theta[col, col]
1807+
1808+
else: # primal
1809+
Theta[indices != col, col] = beta / S[col, col]
1810+
Theta[col, indices != col] = beta / S[col, col]
1811+
Theta[col, col] = (1/S[col, col] +
1812+
Theta[col, indices != col] @
1813+
inv_Theta_11 @
1814+
Theta[indices != col, col])
1815+
W[col, col] = (1/(Theta[col, col] -
1816+
Theta[indices != col, col] @
1817+
inv_Theta_11 @
1818+
Theta[indices != col, col]))
1819+
W[indices != col, col] = (-W[col, col] *
1820+
inv_Theta_11 @
1821+
Theta[indices != col, col])
1822+
W[col, indices != col] = (-W[col, col] *
1823+
inv_Theta_11 @
1824+
Theta[indices != col, col])
1825+
# Maybe W_11 can be done smarter ?
1826+
W[_11] = (inv_Theta_11 +
1827+
np.outer(W[indices != col, col],
1828+
W[indices != col, col])/W[col, col])
17961829

17971830
if np.linalg.norm(Theta - Theta_old) < self.tol:
17981831
print(f"Weighted Glasso converged at CD epoch {it + 1}")
17991832
break
18001833
else:
1801-
print(f"Not converged at epoch {it + 1}, "
1802-
f"diff={np.linalg.norm(Theta - Theta_old):.2e}")
1834+
print(
1835+
f"Not converged at epoch {it + 1}, "
1836+
f"diff={np.linalg.norm(Theta - Theta_old):.2e}"
1837+
)
18031838
self.precision_, self.covariance_ = Theta, W
18041839
self.n_iter_ = it + 1
18051840

@@ -1810,33 +1845,47 @@ class AdaptiveGraphicalLasso():
18101845
def __init__(
18111846
self,
18121847
alpha=1.,
1848+
strategy="log",
18131849
n_reweights=5,
18141850
max_iter=1000,
18151851
tol=1e-8,
18161852
warm_start=False,
1817-
# verbose=False,
18181853
):
18191854
self.alpha = alpha
1855+
self.strategy = strategy
18201856
self.n_reweights = n_reweights
18211857
self.max_iter = max_iter
18221858
self.tol = tol
18231859
self.warm_start = warm_start
18241860

18251861
def fit(self, S):
18261862
glasso = GraphicalLasso(
1827-
alpha=self.alpha, algo="mazumder", max_iter=self.max_iter,
1828-
tol=self.tol, warm_start=True)
1863+
alpha=self.alpha,
1864+
algo="primal",
1865+
max_iter=self.max_iter,
1866+
tol=self.tol,
1867+
warm_start=True)
18291868
Weights = np.ones(S.shape)
18301869
self.n_iter_ = []
18311870
for it in range(self.n_reweights):
18321871
glasso.weights = Weights
18331872
glasso.fit(S)
18341873
Theta = glasso.precision_
1835-
Weights = 1/(np.abs(Theta) + 1e-10)
1874+
if self.strategy == "log":
1875+
Weights = 1/(np.abs(Theta) + 1e-10)
1876+
elif self.strategy == "sqrt":
1877+
Weights = 1/(2*np.sqrt(np.abs(Theta)) + 1e-10)
1878+
elif self.strategy == "mcp":
1879+
gamma = 3.
1880+
Weights = np.zeros_like(Theta)
1881+
Weights[np.abs(Theta) < gamma*self.alpha] = (self.alpha -
1882+
np.abs(Theta[np.abs(Theta) < gamma*self.alpha])/gamma)
1883+
else:
1884+
raise ValueError(f"Unknown strategy {self.strategy}")
1885+
18361886
self.n_iter_.append(glasso.n_iter_)
18371887
# TODO print losses for original problem?
18381888
glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True)
18391889
self.precision_ = glasso.precision_
18401890
self.covariance_ = glasso.covariance_
1841-
18421891
return self

skglm/solvers/gram_cd.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from skglm.solvers.base import BaseSolver
77
from skglm.utils.anderson import AndersonAcceleration
8+
from skglm.utils.prox_funcs import ST_vec
89

910

1011
class GramCD(BaseSolver):
@@ -118,7 +119,8 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
118119

119120
# perform Anderson extrapolation
120121
if self.use_acc:
121-
w_acc, grad_acc, is_extrapolated = accelerator.extrapolate(w, grad)
122+
w_acc, grad_acc, is_extrapolated = accelerator.extrapolate(
123+
w, grad)
122124

123125
if is_extrapolated:
124126
# omit constant term for comparison
@@ -165,3 +167,35 @@ def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd):
165167
grad += (w[j] - old_w_j) * scaled_gram[:, j]
166168

167169
return penalty.subdiff_distance(w, grad, all_features)
170+
171+
172+
@njit
173+
def barebones_cd_gram(H, q, x, alpha, weights, max_iter=100, tol=1e-4):
174+
"""
175+
Solve min .5 * x.T H x + q.T @ x + alpha * norm(x, 1).
176+
177+
H must be symmetric.
178+
"""
179+
dim = H.shape[0]
180+
lc = np.zeros(dim)
181+
for j in range(dim):
182+
lc[j] = H[j, j]
183+
184+
# Hx = H @ x
185+
Hx = np.dot(H, x)
186+
for _ in range(max_iter):
187+
max_delta = 0 # max coeff change
188+
189+
for j in range(dim):
190+
x_j_prev = x[j]
191+
x[j] = ST_vec(x[j] - (Hx[j] + q[j]) / lc[j],
192+
alpha*weights[j] / lc[j])
193+
194+
max_delta = max(max_delta, np.abs(x_j_prev - x[j]))
195+
196+
if x_j_prev != x[j]:
197+
Hx += (x[j] - x_j_prev) * H[j]
198+
if max_delta <= tol:
199+
break
200+
201+
return x

0 commit comments

Comments
 (0)