Skip to content

Commit 381ad3c

Browse files
committed
DOC: update lightning documentation
1 parent d820862 commit 381ad3c

File tree

66 files changed

+1192
-205
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+1192
-205
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
=====================================
3+
Signal recovery by 1D total variation
4+
=====================================
5+
6+
In this example, we generate a signal that is piecewise constant. We then
7+
observe some random and corrupted measurements from that signal and
8+
then try to recover that signal using L1 and TV1D penalties.
9+
10+
Given a ground truth vector, the signal that we observe is given by
11+
12+
y = sign(X ground_truth + noise)
13+
14+
where X is a random matrix. We recover an estimate of ground_truth by solving
15+
an optimization problem using lightning FistaClassifier.
16+
17+
The 1D total variation is also known as fused lasso.
18+
"""
19+
20+
import numpy as np
21+
import matplotlib.pyplot as plt
22+
from lightning.classification import FistaClassifier
23+
from sklearn.grid_search import GridSearchCV
24+
25+
# generate some synthetic data
26+
n_samples = 200
27+
ground_truth = np.concatenate((
28+
np.ones(20), - np.ones(20), np.zeros(40)))
29+
n_features = ground_truth.size
30+
np.random.seed(0) # for reproducibility
31+
X = np.random.rand(n_samples, n_features)
32+
# generate y as a linear model, y = sign(X w + noise)
33+
y = np.sign(X.dot(ground_truth) + 0.5 * np.random.randn(n_samples)).astype(np.int)
34+
35+
36+
for penalty in ('l1', 'tv1d'):
37+
clf = FistaClassifier(penalty=penalty)
38+
gs = GridSearchCV(clf, {'alpha': np.logspace(-3, 3, 10)})
39+
gs.fit(X, y)
40+
coefs = gs.best_estimator_.coef_
41+
plt.plot(coefs.ravel(), label='%s penalty' % penalty, lw=3)
42+
43+
plt.plot(ground_truth, lw=3, marker='^', markevery=5, markersize=10, label="ground truth")
44+
plt.grid()
45+
plt.legend()
46+
plt.ylim((-1.5, 1.5))
47+
plt.show()

lightning/_downloads/plot_l2_solvers.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from lightning.classification import SDCAClassifier
2121
from lightning.classification import CDClassifier
2222
from lightning.classification import AdaGradClassifier
23-
from lightning.classification import SAGClassifier
23+
from lightning.classification import SAGAClassifier, SAGClassifier
2424

2525
from lightning.impl.adagrad_fast import _proj_elastic_all
2626

@@ -66,10 +66,8 @@ def __call__(self, clf, t=None):
6666
y[y >= 1] = 1
6767
alpha = 1e-4
6868
eta_svrg = 1e-1
69-
eta_sag = 1
7069
eta_adagrad = 1
71-
xlim = (0, 4)
72-
ylim = (0.04, 0.1)
70+
xlim = (0, 20)
7371

7472
else:
7573
X, y = make_classification(n_samples=10000,
@@ -78,32 +76,38 @@ def __call__(self, clf, t=None):
7876
random_state=0)
7977
alpha = 1e-2
8078
eta_svrg = 1e-3
81-
eta_sag = 1e-3
8279
eta_adagrad = 1e-2
83-
xlim = None
84-
ylim = (0.5, 0.6)
80+
xlim = [0, 2]
8581

8682
y = y * 2 - 1
8783

84+
# make sure the solvers do not stop prematurely; we want to see
85+
# the full convergence path
86+
tol = 1e-24
8887

8988
clf1 = SVRGClassifier(loss="squared_hinge", alpha=alpha, eta=eta_svrg,
90-
n_inner=1.0, max_iter=50, random_state=0)
89+
n_inner=1.0, max_iter=100, random_state=0, tol=1e-24)
9190
clf2 = SDCAClassifier(loss="squared_hinge", alpha=alpha,
92-
max_iter=50, n_calls=X.shape[0]/2, random_state=0)
91+
max_iter=100, n_calls=X.shape[0]/2, random_state=0, tol=tol)
9392
clf3 = CDClassifier(loss="squared_hinge", alpha=alpha, C=1.0/X.shape[0],
94-
max_iter=50, n_calls=X.shape[1]/3, random_state=0)
93+
max_iter=50, n_calls=X.shape[1]/3, random_state=0, tol=tol)
9594
clf4 = AdaGradClassifier(loss="squared_hinge", alpha=alpha, eta=eta_adagrad,
96-
n_iter=50, n_calls=X.shape[0]/2, random_state=0)
97-
clf5 = SAGClassifier(loss="squared_hinge", alpha=alpha, eta=eta_sag,
98-
max_iter=50, random_state=0)
95+
n_iter=100, n_calls=X.shape[0]/2, random_state=0)
96+
clf5 = SAGAClassifier(loss="squared_hinge", alpha=alpha,
97+
max_iter=100, random_state=0, tol=tol)
98+
clf6 = SAGClassifier(loss="squared_hinge", alpha=alpha,
99+
max_iter=100, random_state=0, tol=tol)
99100

100101
plt.figure()
101102

103+
data = {}
102104
for clf, name in ((clf1, "SVRG"),
103105
(clf2, "SDCA"),
104106
(clf3, "PCD"),
105107
(clf4, "AdaGrad"),
106-
(clf5, "SAG")):
108+
(clf5, "SAGA"),
109+
(clf6, "SAG")
110+
):
107111
print(name)
108112
cb = Callback(X, y)
109113
clf.callback = cb
@@ -112,13 +116,18 @@ def __call__(self, clf, t=None):
112116
clf.fit(X.tocsc(), y)
113117
else:
114118
clf.fit(X, y)
119+
data[name] = (cb.times, np.array(cb.obj))
115120

116-
plt.plot(cb.times, cb.obj, label=name)
121+
# get best value
122+
fmin = min([np.min(a[1]) for a in data.values()])
123+
for name in data:
124+
plt.plot(data[name][0], data[name][1] - fmin, label=name, lw=3)
117125

118126
plt.xlim(xlim)
119-
plt.ylim(ylim)
127+
plt.yscale('log')
120128
plt.xlabel("CPU time")
121-
plt.ylabel("Objective value")
129+
plt.ylabel("Objective value minus optimum")
122130
plt.legend()
131+
plt.grid()
123132

124133
plt.show()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
======================
3+
SAGA: Weighted samples
4+
======================
5+
6+
Plot decision function of a weighted dataset, where the size of points
7+
is proportional to their weight.
8+
9+
Adapted from scikit-learn's plot_sgd_weighted_samples.py
10+
"""
11+
print(__doc__)
12+
13+
import numpy as np
14+
import matplotlib.pyplot as plt
15+
from lightning.impl.sag import SAGAClassifier
16+
17+
# we create 20 points
18+
np.random.seed(0)
19+
X = np.r_[np.random.randn(10, 2) + [1, 1], np.random.randn(10, 2)]
20+
y = np.array([1] * 10 + [-1] * 10)
21+
sample_weight = 100 * np.abs(np.random.randn(20))
22+
# and assign a bigger weight to the first 10 samples
23+
sample_weight[:10] *= 10
24+
25+
# plot the weighted data points
26+
xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))
27+
plt.figure()
28+
plt.scatter(X[:, 0], X[:, 1], c=y, s=sample_weight, alpha=0.9,
29+
cmap=plt.cm.bone)
30+
31+
# fit the unweighted model
32+
clf = SAGAClassifier(alpha=0.01, loss='log')
33+
clf.fit(X, y)
34+
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
35+
Z = Z.reshape(xx.shape)
36+
no_weights = plt.contour(xx, yy, Z, levels=[0], linestyles=['solid'])
37+
38+
# fit the weighted model
39+
clf = SAGAClassifier(alpha=0.01, loss='log')
40+
clf.fit(X, y, sample_weight=sample_weight)
41+
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
42+
Z = Z.reshape(xx.shape)
43+
samples_weights = plt.contour(xx, yy, Z, levels=[0], linestyles=['dashed'])
44+
45+
plt.legend([no_weights.collections[0], samples_weights.collections[0]],
46+
["no weights", "with weights"], loc="lower left")
47+
48+
plt.xticks(())
49+
plt.yticks(())
50+
plt.show()
43.2 KB
Loading
54.4 KB
Loading
27.1 KB
Loading
38.9 KB
Loading
19.6 KB
Loading
41 KB
Loading

lightning/_modules/index.html

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,11 @@
7373
role="menu"
7474
aria-labelledby="dLabelGlobalToc"><ul>
7575
<li class="toctree-l1"><a class="reference internal" href="../auto_examples/index.html">Examples</a><ul>
76+
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/plot_1d_total_variation.html">Signal recovery by 1D total variation</a></li>
7677
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/plot_sgd_loss_functions.html">SGD: Convex Loss Functions</a></li>
7778
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/plot_robust_regression.html">Robust regression</a></li>
7879
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/trace.html">Trace norm</a></li>
80+
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/plot_sample_weight.html">SAGA: Weighted samples</a></li>
7981
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/document_classification_news20.html">Classification of text documents</a></li>
8082
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/plot_svrg.html">Sensitivity to hyper-parameters in SVRG</a></li>
8183
<li class="toctree-l2"><a class="reference internal" href="../auto_examples/plot_sparse_non_linear.html">Sparse non-linear classification</a></li>
@@ -118,13 +120,15 @@
118120
<li class="toctree-l2"><a class="reference internal" href="../intro.html#fista">FISTA</a></li>
119121
<li class="toctree-l2"><a class="reference internal" href="../intro.html#stochastic-gradient-method-sgd">Stochastic gradient method (SGD)</a></li>
120122
<li class="toctree-l2"><a class="reference internal" href="../intro.html#adagrad">AdaGrad</a></li>
121-
<li class="toctree-l2"><a class="reference internal" href="../intro.html#stochastic-averaged-gradient-sag">Stochastic averaged gradient (SAG)</a></li>
123+
<li class="toctree-l2"><a class="reference internal" href="../intro.html#stochastic-averaged-gradient-sag-and-saga">Stochastic averaged gradient (SAG and SAGA)</a></li>
122124
<li class="toctree-l2"><a class="reference internal" href="../intro.html#stochastic-variance-reduced-gradient-svrg">Stochastic variance-reduced gradient (SVRG)</a></li>
123125
<li class="toctree-l2"><a class="reference internal" href="../intro.html#prank">PRank</a><ul>
124126
<li class="toctree-l3"><a class="reference internal" href="../auto_examples/index.html">Examples</a><ul>
127+
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/plot_1d_total_variation.html">Signal recovery by 1D total variation</a></li>
125128
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/plot_sgd_loss_functions.html">SGD: Convex Loss Functions</a></li>
126129
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/plot_robust_regression.html">Robust regression</a></li>
127130
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/trace.html">Trace norm</a></li>
131+
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/plot_sample_weight.html">SAGA: Weighted samples</a></li>
128132
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/document_classification_news20.html">Classification of text documents</a></li>
129133
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/plot_svrg.html">Sensitivity to hyper-parameters in SVRG</a></li>
130134
<li class="toctree-l4"><a class="reference internal" href="../auto_examples/plot_sparse_non_linear.html">Sparse non-linear classification</a></li>

0 commit comments

Comments
 (0)