Skip to content

Commit 63d32f2

Browse files
Better tests (#2)
* add more tests * separate file for MaxnetModel and its methods * add an example to predict * specify types of fields in MaxnetModel * fix the order of loading jl files * slightly increase tolerance so test passes
1 parent 3a0b72a commit 63d32f2

File tree

5 files changed

+65
-42
lines changed

5 files changed

+65
-42
lines changed

src/Maxnet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ using MLJModelInterface: Continuous, Binary, Multiclass, Count
99

1010
export IdentityLink, CloglogLink, LogitLink, LogLink # re-export relevant links
1111
export LassoBackend, GLMNetBackend
12-
export maxnet, predict
12+
export maxnet, predict, complexity
1313
export LinearFeature, CategoricalFeature, QuadraticFeature, ProductFeature, ThresholdFeature, HingeFeature
1414
export MaxnetBinaryClassifier
1515

16-
# Write your package code here.
1716

1817
include("utils.jl")
1918
include("lasso.jl")
2019
include("feature_classes.jl")
2120
include("model_matrix.jl")
2221
include("regularization.jl")
22+
include("MaxnetModel.jl")
2323
include("maxnet_function.jl")
2424
include("predict.jl")
2525
include("response_curves.jl")

src/MaxnetModel.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
struct MaxnetModel
2+
path::Union{GLMNet.GLMNetPath, Lasso.LassoPath}
3+
features::Vector{<:AbstractFeatureClass}
4+
columns::Vector{ModelMatrixColumn}
5+
coefs::AbstractVector
6+
alpha::Float64
7+
entropy::Float64
8+
predictor_data
9+
categorical_predictors::NTuple{<:Any, Symbol}
10+
continuous_predictors::NTuple{<:Any, Symbol}
11+
end
12+
13+
function Base.show(io::IO, mime::MIME"text/plain", m::MaxnetModel)
14+
vars_selected = mapreduce(Maxnet._var_keys, (x, y) -> unique(vcat(x, y)), selected_features(m))
15+
16+
println(io, "Fit Maxnet model")
17+
18+
println(io, "Features classes: $(m.features)")
19+
println(io, "Entropy: $(m.entropy)")
20+
println(io, "Model complexity: $(complexity(m))")
21+
println(io, "Variables selected: $vars_selected")
22+
end
23+
24+
"Get the number of non-zero coefficients in the model"
25+
complexity(m::MaxnetModel) = length(m.coefs.nzval)

src/maxnet_function.jl

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,3 @@
1-
struct MaxnetModel
2-
path
3-
features
4-
columns
5-
coefs
6-
alpha
7-
entropy
8-
predictor_data
9-
categorical_predictors
10-
continuous_predictors
11-
end
12-
13-
function Base.show(io::IO, mime::MIME"text/plain", m::MaxnetModel)
14-
vars_selected = mapreduce(Maxnet._var_keys, (x, y) -> unique(vcat(x, y)), selected_features(m))
15-
16-
println(io, "Fit Maxnet model")
17-
18-
println(io, "Features classes: $(m.features)")
19-
println(io, "Entropy: $(m.entropy)")
20-
println(io, "Model complexity: $(length(m.coefs.nzval))")
21-
println(io, "Variables selected: $vars_selected")
22-
end
23-
241
"""
252
maxnet(
263
presences, predictors;
@@ -52,14 +29,11 @@ Lasso.jl is written in pure julia, but can be slower with large model matrices (
5229
- `model`: A model of type `MaxnetModel`
5330
5431
# Examples
55-
```jldoctest
32+
```julia
5633
using Maxnet
57-
p_a, env = Maxnet.bradypus()
58-
34+
p_a, env = Maxnet.bradypus();
5935
bradypus_model = maxnet(p_a, env; features = "lq", backend = GLMNetBackend())
6036
61-
# output
62-
6337
Fit Maxnet model
6438
Features classes: Maxnet.AbstractFeatureClass[LinearFeature(), CategoricalFeature(), QuadraticFeature()]
6539
Entropy: 6.114650341746531

src/predict.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
# Returns
1818
A `Vector` with the resulting predictions.
1919
20+
# Example
21+
```julia
22+
using Maxnet
23+
p_a, env = Maxnet.bradypus();
24+
bradypus_model = maxnet(p_a, env; features = "lq")
25+
prediction = Maxnet.predict(bradypus_model, env)
26+
```
2027
"""
2128
function predict(m::MaxnetModel, x; link = CloglogLink(), clamp = false)
2229
predictors = Tables.columntable(x)

test/runtests.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
using Maxnet, Test, MLJBase, Statistics
22

33
p_a, env = Maxnet.bradypus()
4-
env1 = map(e -> [e[1]], env)
4+
env1 = map(e -> [e[1]], env) # just the first row
55

66
@testset "utils" begin
77
@test_throws ErrorException Maxnet.features_from_string("a")
8+
# test each feature class is returned correctly
89
@test Maxnet.features_from_string("l") == [LinearFeature(), CategoricalFeature()]
910
@test Maxnet.features_from_string("q") == [QuadraticFeature()]
11+
@test Maxnet.features_from_string("lq") == [LinearFeature(), CategoricalFeature(), QuadraticFeature()]
12+
@test Maxnet.features_from_string("lqp") == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), ProductFeature()]
13+
@test Maxnet.features_from_string("lqph") == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), ProductFeature(), HingeFeature()]
14+
@test Maxnet.features_from_string("lqpt") == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), ProductFeature(), ThresholdFeature()]
15+
16+
@test Maxnet.default_features(100) == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), HingeFeature(), ProductFeature()]
1017
end
1118

1219
@testset "Maxnet" begin
@@ -18,31 +25,34 @@ end
1825
@test all(isapprox.(model_glmnet.coefs, model_lasso.coefs; rtol = 0.1, atol = 0.1))
1926
@test Statistics.cor(model_glmnet.coefs, model_lasso.coefs) > 0.99
2027

21-
# select classes automatically
22-
Maxnet.maxnet(p_a, env; backend = LassoBackend());
28+
# test the result
29+
@test model_glmnet.entropy 6.114650341746531
30+
@test complexity(model_glmnet) == 21
2331

24-
# some class combinations
32+
# some class combinations and keywords
2533
Maxnet.maxnet(p_a, env; features = "lq", backend = LassoBackend());
2634
Maxnet.maxnet(p_a, env; features = "lqp", regularization_multiplier = 2., backend = LassoBackend());
27-
Maxnet.maxnet(p_a, env; features = "lqh", regularization_multiplier = 5., backend = LassoBackend());
28-
Maxnet.maxnet(p_a, env; features = "lqph", backend = LassoBackend());
29-
Maxnet.maxnet(p_a, env; features = "lqpt", backend = LassoBackend());
35+
Maxnet.maxnet(p_a, env; features = "lqh", regularization_multiplier = 5., nknots = 10, backend = LassoBackend());
36+
Maxnet.maxnet(p_a, env; features = "lqph", weight_factor = 10., backend = LassoBackend());
3037

3138
# predictions
3239
prediction = Maxnet.predict(model_lasso, env)
3340
@test Statistics.mean(prediction[p_a]) > Statistics.mean(prediction[.~p_a])
3441
@test minimum(prediction) > 0.
3542
@test maximum(prediction) < 1.
43+
@test mean(prediction) 0.243406167194403 atol=1e-4
3644

45+
# check that clamping works
3746
# clamp shouldn't change anything in this case
38-
@test all(prediction .== Maxnet.predict(model_lasso, env; clamp = true))
47+
@test prediction == Maxnet.predict(model_lasso, env; clamp = true)
3948

4049
# predict with a crazy extrapolation
4150
env1_extrapolated = merge(env1, (;cld6190_ann = [100_000]))
42-
# without clamp the prediction is crazy
43-
@test abs(Maxnet.predict(model_lasso, env1_extrapolated; link = IdentityLink())[1]) > 100_000.
44-
# without clamp the prediction is reasonable
45-
@test abs(Maxnet.predict(model_lasso, env1_extrapolated; link = IdentityLink(), clamp = true)[1]) < 5.
51+
env1_max_cld = merge(env1, (;cld6190_ann = [maximum(env.cld6190_ann)]))
52+
53+
# using clamp the prediction uses the highest cloud
54+
@test Maxnet.predict(model_lasso, env1_extrapolated; link = IdentityLink(), clamp = true) ==
55+
Maxnet.predict(model_lasso, env1_max_cld; link = IdentityLink())
4656
end
4757

4858
@testset "MLJ" begin
@@ -59,10 +69,17 @@ end
5969
mach2 = machine(mn(features = "lqph", backend = GLMNetBackend()), env_typed, categorical(p_a))
6070
fit!(mach2)
6171

72+
# make the equivalent model without mlj
73+
model = Maxnet.maxnet((p_a), env_typed; features = "lqph", backend = GLMNetBackend());
74+
75+
6276
# predict via MLJBase
6377
mljprediction = MLJBase.predict(mach2, env_typed)
6478
mlj_true_probability = pdf.(mljprediction, true)
6579

80+
# test that this predicts the same as the equivalent model without mlj
81+
@test all(Maxnet.predict(model, env_typed) .≈ mlj_true_probability)
82+
6683
@test Statistics.mean(mlj_true_probability[p_a]) > Statistics.mean(mlj_true_probability[.~p_a])
6784
@test minimum(mlj_true_probability) > 0.
6885
@test maximum(mlj_true_probability) < 1.

0 commit comments

Comments
 (0)