Skip to content

Commit 63d2f66

Browse files
get rid of lasso.jl backend
1 parent c3e6339 commit 63d2f66

File tree

5 files changed

+25
-54
lines changed

5 files changed

+25
-54
lines changed

src/Maxnet.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using GLM: IdentityLink, CloglogLink, LogitLink, LogLink
88
using MLJModelInterface: Continuous, Binary, Multiclass, Count
99

1010
export IdentityLink, CloglogLink, LogitLink, LogLink # re-export relevant links
11-
export LassoBackend, GLMNetBackend
1211
export maxnet, predict, complexity
1312
export LinearFeature, CategoricalFeature, QuadraticFeature, ProductFeature, ThresholdFeature, HingeFeature
1413
export MaxnetBinaryClassifier

src/lasso.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,5 @@
1-
abstract type MaxnetBackend end
2-
struct LassoBackend <: MaxnetBackend end
3-
struct GLMNetBackend <: MaxnetBackend end
4-
5-
function fit_lasso_path(
6-
backend::LassoBackend, mm, presences;
7-
kw...)
8-
9-
Lasso.fit(
10-
Lasso.LassoPath, mm, presences, Lasso.Distributions.Binomial();
11-
standardize = false, irls_maxiter = 1_000, kw...)
12-
end
13-
141
function fit_lasso_path(
15-
backend::GLMNetBackend, mm, presences;
2+
mm, presences;
163
wts, penalty_factor, λ, kw...)
174

185
presence_matrix = [1 .- presences presences]
@@ -22,4 +9,3 @@ function fit_lasso_path(
229
end
2310

2411
get_coefs(path::GLMNet.GLMNetPath) = path.betas
25-
get_coefs(path::Lasso.LassoPath) = path.coefs

src/maxnet_function.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
maxnet(
33
presences, predictors;
44
features, regularization_multiplier, regularization_function,
5-
addsamplestobackground, weight_factor, backend,
5+
addsamplestobackground, weight_factor,
66
kw...
77
)
88
@@ -21,9 +21,7 @@
2121
- `addsamplestobackground`: A boolean, where `true` adds the background samples to the predictors. Defaults to `true`.
2222
- `n_knots`: the number of knots used for Threshold and Hinge features. Defaults to 50. Ignored if there are neither Threshold nor Hinge features
2323
- `weight_factor`: A `Float64` value to adjust the weight of the background samples. Defaults to 100.0.
24-
- `backend`: Either `LassoBackend()` or `GLMNetBackend()`, to use either Lasso.jl or GLMNet.jl to fit the model.
25-
Lasso.jl is written in pure julia, but can be slower with large model matrices (e.g. when hinge is enabled). Defaults to `LassoBackend`.
26-
- `kw...`: Further arguments to be passed to `Lasso.fit` or `GLMNet.glmnet`
24+
- `kw...`: Further arguments to be passed to `GLMNet.glmnet`
2725
2826
# Returns
2927
- `model`: A model of type `MaxnetModel`
@@ -32,7 +30,7 @@ Lasso.jl is written in pure julia, but can be slower with large model matrices (
3230
```julia
3331
using Maxnet
3432
p_a, env = Maxnet.bradypus();
35-
bradypus_model = maxnet(p_a, env; features = "lq", backend = GLMNetBackend())
33+
bradypus_model = maxnet(p_a, env; features = "lq")
3634
3735
Fit Maxnet model
3836
Features classes: Maxnet.AbstractFeatureClass[LinearFeature(), CategoricalFeature(), QuadraticFeature()]
@@ -49,7 +47,6 @@ function maxnet(
4947
regularization_function = default_regularization,
5048
addsamplestobackground::Bool = true, weight_factor::Float64 = 100.,
5149
n_knots::Int = 50,
52-
backend::MaxnetBackend = LassoBackend(),
5350
kw...)
5451

5552
_maxnet(
@@ -60,8 +57,7 @@ function maxnet(
6057
regularization_function,
6158
addsamplestobackground,
6259
weight_factor,
63-
n_knots,
64-
backend;
60+
n_knots;
6561
kw...
6662
)
6763
end
@@ -90,8 +86,7 @@ function _maxnet(
9086
regularization_function,
9187
addsamplestobackground::Bool,
9288
weight_factor::Float64,
93-
n_knots::Int,
94-
backend::MaxnetBackend;
89+
n_knots::Int;
9590
kw...)
9691

9792
# check if predictors is a table
@@ -137,7 +132,7 @@ function _maxnet(
137132
λ = lambdas(reg, presences, weights; λmax = 4, n = 200)
138133

139134
# Fit the model
140-
lassopath = fit_lasso_path(backend, mm, presences, wts = weights, penalty_factor = reg, λ = λ)
135+
lassopath = fit_lasso_path(mm, presences, wts = weights, penalty_factor = reg, λ = λ)
141136

142137
# get the coefficients out
143138
coefs = SparseArrays.sparse(get_coefs(lassopath)[:, end])

src/mlj_interface.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ mutable struct MaxnetBinaryClassifier <: MMI.Probabilistic
33
regularization_multiplier::Float64
44
regularization_function
55
weight_factor::Float64
6-
backend::MaxnetBackend
76
link::GLM.Link
87
clamp::Bool
98
kw
@@ -12,14 +11,14 @@ end
1211
function MaxnetBinaryClassifier(;
1312
features="",
1413
regularization_multiplier = 1.0, regularization_function = default_regularization,
15-
weight_factor = 100., backend = LassoBackend(),
14+
weight_factor = 100.,
1615
link = CloglogLink(), clamp = false,
1716
kw...
1817
)
1918

2019
MaxnetBinaryClassifier(
2120
features, regularization_multiplier, regularization_function,
22-
weight_factor, backend, link, clamp, kw
21+
weight_factor, link, clamp, kw
2322
)
2423
end
2524

@@ -63,7 +62,6 @@ function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y)
6362
regularization_multiplier = m.regularization_multiplier,
6463
regularization_function = m.regularization_function,
6564
weight_factor = m.weight_factor,
66-
backend = m.backend,
6765
m.kw...)
6866

6967
decode = MMI.classes(y)

test/runtests.jl

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,34 @@ env1 = map(e -> [e[1]], env) # just the first row
1717
end
1818

1919
@testset "Maxnet" begin
20-
# test both backends work
21-
model_glmnet = Maxnet.maxnet((p_a), env; features = "lq", backend = GLMNetBackend());
22-
model_lasso = Maxnet.maxnet((p_a), env; features = "lq", backend = LassoBackend());
23-
24-
# test both backends come up with approximately the same result
25-
@test all(isapprox.(model_glmnet.coefs, model_lasso.coefs; rtol = 0.1, atol = 0.1))
26-
@test Statistics.cor(model_glmnet.coefs, model_lasso.coefs) > 0.99
20+
# some class combinations and keywords
21+
m = Maxnet.maxnet(p_a, env; features = "lq");
22+
Maxnet.maxnet(p_a, env; features = "lqp", regularization_multiplier = 2.);
23+
Maxnet.maxnet(p_a, env; features = "lqh", regularization_multiplier = 5., nknots = 10);
24+
Maxnet.maxnet(p_a, env; features = "lqph", weight_factor = 10.);
2725

2826
# test the result
29-
@test model_glmnet.entropy ≈ 6.114650341746531
30-
@test complexity(model_glmnet) == 21
31-
32-
# some class combinations and keywords
33-
Maxnet.maxnet(p_a, env; features = "lq", backend = LassoBackend());
34-
Maxnet.maxnet(p_a, env; features = "lqp", regularization_multiplier = 2., backend = LassoBackend());
35-
Maxnet.maxnet(p_a, env; features = "lqh", regularization_multiplier = 5., nknots = 10, backend = LassoBackend());
36-
Maxnet.maxnet(p_a, env; features = "lqph", weight_factor = 10., backend = LassoBackend());
27+
@test m.entropy ≈ 6.114650341746531
28+
@test complexity(m) == 21
3729

3830
# predictions
39-
prediction = Maxnet.predict(model_lasso, env)
31+
prediction = Maxnet.predict(m, env)
4032
@test Statistics.mean(prediction[p_a]) > Statistics.mean(prediction[.~p_a])
4133
@test minimum(prediction) > 0.
4234
@test maximum(prediction) < 1.
43-
@test mean(prediction) ≈ 0.243406167194403 atol=1e-4
35+
@test mean(prediction) ≈ 0.24375837576014572 atol=1e-4
4436

4537
# check that clamping works
4638
# clamp shouldn't change anything in this case
47-
@test prediction == Maxnet.predict(model_lasso, env; clamp = true)
39+
@test prediction == Maxnet.predict(m, env; clamp = true)
4840

4941
# predict with a crazy extrapolation
5042
env1_extrapolated = merge(env1, (;cld6190_ann = [100_000]))
5143
env1_max_cld = merge(env1, (;cld6190_ann = [maximum(env.cld6190_ann)]))
5244

5345
# using clamp the prediction uses the highest cloud
54-
@test Maxnet.predict(model_lasso, env1_extrapolated; link = IdentityLink(), clamp = true) ==
55-
Maxnet.predict(model_lasso, env1_max_cld; link = IdentityLink())
46+
@test Maxnet.predict(m, env1_extrapolated; link = IdentityLink(), clamp = true) ==
47+
Maxnet.predict(m, env1_max_cld; link = IdentityLink())
5648
end
5749

5850
@testset "MLJ" begin
@@ -63,21 +55,22 @@ end
6355
env_typed = MLJBase.coerce(env, cont_keys...)
6456

6557
# make a machine
66-
mach1 = machine(mn(features = "lq", backend = LassoBackend()), env_typed, categorical(p_a))
58+
mach1 = machine(mn(features = "lq"), env_typed, categorical(p_a))
6759
fit!(mach1)
6860

69-
mach2 = machine(mn(features = "lqph", backend = GLMNetBackend()), env_typed, categorical(p_a))
61+
mach2 = machine(mn(features = "lqph"), env_typed, categorical(p_a))
7062
fit!(mach2)
7163

7264
# make the equivalent model without mlj
73-
model = Maxnet.maxnet((p_a), env_typed; features = "lqph", backend = GLMNetBackend());
65+
model = Maxnet.maxnet((p_a), env_typed; features = "lqph");
7466

7567

7668
# predict via MLJBase
7769
mljprediction = MLJBase.predict(mach2, env_typed)
7870
mlj_true_probability = pdf.(mljprediction, true)
7971

8072
# test that this predicts the same as the equivalent model without mlj
73+
8174
@test all(Maxnet.predict(model, env_typed) .≈ mlj_true_probability)
8275

8376
@test Statistics.mean(mlj_true_probability[p_a]) > Statistics.mean(mlj_true_probability[.~p_a])

0 commit comments

Comments
 (0)