Skip to content

Commit 93a2c31

Browse files
Pass keywords to glmnet (#19)
* actually pass keywords to glmnet * test passing of keywords to glmnet
1 parent bf76f47 commit 93a2c31

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

src/lasso.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
function fit_lasso_path(
22
mm, presences;
3-
wts, penalty_factor, λ, kw...)
3+
weights, penalty_factor, lambda, kw...)
44

55
presence_matrix = [1 .- presences presences]
66
GLMNet.glmnet(
77
mm, presence_matrix, GLMNet.Binomial();
8-
weights = wts, penalty_factor = penalty_factor, lambda = λ, standardize = false)
8+
weights, penalty_factor, lambda, standardize = false, kw...)
99
end
1010

1111
get_coefs(path::GLMNet.GLMNetPath) = path.betas

src/maxnet_function.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ function maxnet(
6161
kw...
6262
)
6363
end
64-
#maxnet(presences, predictors; kw...) = maxnet(presences, predictors, features; kw...)
6564

6665
### internal methods where features is not a keyword
6766

@@ -129,10 +128,10 @@ function _maxnet(
129128
weights = presences .* 1. .+ (1 .- presences) .* weight_factor
130129

131130
# generate lambdas
132-
λ = lambdas(reg, presences, weights; λmax = 4, n = 200)
131+
lambda = lambdas(reg, presences, weights; λmax = 4, n = 200)
133132

134133
# Fit the model
135-
lassopath = fit_lasso_path(mm, presences, wts = weights, penalty_factor = reg, λ = λ)
134+
lassopath = fit_lasso_path(mm, presences; weights, penalty_factor = reg, lambda, kw...)
136135

137136
# get the coefficients out
138137
coefs = SparseArrays.sparse(get_coefs(lassopath)[:, end])

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ env1 = map(e -> [e[1]], env) # just the first row
3131
@test predictors == (a = [1,2,3,1], b = [1,2,3,1])
3232
end
3333

34+
3435
@testset "Maxnet" begin
3536
# some class combinations and keywords
3637
m = maxnet(p_a, env; features = "lq");
@@ -72,7 +73,14 @@ end
7273
@test complexity(empty_model) == 0
7374
@test Maxnet.selected_features(empty_model) == Symbol[]
7475
@test length(unique(predict(empty_model, env))) == 1
76+
77+
# test that keywords arguments are passed to glmnet
78+
weights = ifelse.(p_a, 1.0, 10.0)
79+
m_w = maxnet(p_a, env; features = "lq", addsamplestobackground = false, weights)
80+
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)
81+
@test m_w.entropy > m.entropy
7582
end
83+
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)
7684

7785
@testset "MLJ" begin
7886
using MLJBase

0 commit comments

Comments
 (0)