Commit adb4cc8

Merge pull request #16 from tiemvanderdeure/statsapi_predict
use StatsAPI.predict, add mlj metadata, better tests, some other minor changes
2 parents 60f4501 + 29ee3ca commit adb4cc8
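The headline change: `predict` is now a method of `StatsAPI.predict` and is re-exported from Maxnet, so it can be called unqualified. A minimal sketch of the resulting workflow, pieced together from the quickstart and the `predict` docstring touched in this diff (`Maxnet.bradypus()` is the example dataset those files use):

    using Maxnet

    # example presence/absence vector and environmental predictors bundled with the package
    p_a, env = Maxnet.bradypus()

    # fit with linear and quadratic features, then predict (cloglog output by default)
    bradypus_model = maxnet(p_a, env; features = "lq")
    prediction = predict(bradypus_model, env)   # now resolves to StatsAPI.predict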

7 files changed: +77 additions, -26 deletions

docs/src/usage/quickstart.md

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ bradypus_model = maxnet(p_a, env)
 prediction = predict(bradypus_model, env)
 ```
 
-There are numerous settings that can be tweaked to change the model fit. These are documentated in the documentatoin for the `maxnet`(@ref) and `Maxnet.predict`(@ref) functions.
+There are numerous settings that can be tweaked to change the model fit. These are documentated in the documentation for the `maxnet`(@ref) and `predict`(@ref) functions.
 
 ### Model settings
 The two most important settings to change when running Maxnet is the feature classes selected and the regularization factor.
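As a rough illustration of those two settings, both keywords appear in the `maxnet` signature further down in this diff; the feature strings and multiplier value here are arbitrary examples, not recommendations, and the variable names are made up:

    # feature classes via the string shorthand, regularization via a multiplier
    m_default  = maxnet(p_a, env; features = "lq")
    m_smoother = maxnet(p_a, env; features = "lqp", regularization_multiplier = 2.0)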

src/Maxnet.jl

Lines changed: 2 additions & 1 deletion

@@ -4,12 +4,13 @@ import Tables, Lasso, GLMNet, Interpolations, CategoricalArrays, GLM, SparseArra
 import StatsAPI, StatsBase, Statistics
 import MLJModelInterface as MMI
 
+using StatsAPI: predict
 using GLM: IdentityLink, CloglogLink, LogitLink, LogLink
 using MLJModelInterface: Continuous, Binary, Multiclass, Count
 
 export IdentityLink, CloglogLink, LogitLink, LogLink # re-export relevant links
 export maxnet, predict, complexity
-export LinearFeature, CategoricalFeature, QuadraticFeature, ProductFeature, ThresholdFeature, HingeFeature
+export LinearFeature, CategoricalFeature, QuadraticFeature, ProductFeature, ThresholdFeature, HingeFeature, AbstractFeatureClass
 export MaxnetBinaryClassifier
 
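With `AbstractFeatureClass` now exported, user code can refer to the feature-class hierarchy without qualifying the name. A small sketch, assuming the concrete feature types subtype `AbstractFeatureClass` as the `features::Vector{<:AbstractFeatureClass}` annotation in src/maxnet_function.jl suggests:

    LinearFeature() isa AbstractFeatureClass       # true
    my_features = AbstractFeatureClass[LinearFeature(), QuadraticFeature(), HingeFeature()]
    m = maxnet(p_a, env; features = my_features)   # explicit vector instead of the "lqh" shorthand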

src/maxnet_function.jl

Lines changed: 2 additions & 2 deletions

@@ -43,7 +43,7 @@ Variables selected: [:frs6190_ann, :h_dem, :pre6190_l1, :pre6190_l10, :pre6190_l
 function maxnet(
     presences::BitVector, predictors;
     features = default_features(sum(presences)),
-    regularization_multiplier::Float64 = 1.0,
+    regularization_multiplier = 1.0,
     regularization_function = default_regularization,
     addsamplestobackground::Bool = true, weight_factor::Float64 = 100.,
     n_knots::Int = 50,
@@ -82,7 +82,7 @@ function _maxnet(
     presences::BitVector,
     predictors,
     features::Vector{<:AbstractFeatureClass},
-    regularization_multiplier::Float64,
+    regularization_multiplier,
    regularization_function,
    addsamplestobackground::Bool,
    weight_factor::Float64,
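Dropping the `::Float64` annotation means the multiplier is no longer restricted to `Float64`: an integer keyword argument previously failed the type assertion, and the new `regularization_multiplier = 1000` test below depends on it being accepted. A quick before/after sketch:

    maxnet(p_a, env; regularization_multiplier = 2.0)   # worked before and still works
    maxnet(p_a, env; regularization_multiplier = 2)     # formerly a keyword TypeError, now accepted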

src/mlj_interface.jl

Lines changed: 17 additions & 5 deletions

@@ -46,12 +46,24 @@ end
 """
     MaxnetBinaryClassifier
 
-MMI.input_scitype(::Type{<:MaxnetBinaryClassifier}) =
-    MMI.Table{<:Union{<:AbstractVector{<:Continuous}, <:AbstractVector{<:Multiclass}}} #{<:Union{<:Continuous <:Multiclass}}
-
-MMI.target_scitype(::Type{<:MaxnetBinaryClassifier}) = AbstractVector{Multiclass{2}}# AbstractVector{<:MMI.Finite}
+MMI.metadata_pkg(
+    MaxnetBinaryClassifier;
+    name = "Maxnet",
+    uuid = "81f79f80-22f2-4e41-ab86-00c11cf0f26f",
+    url = "https://github.com/tiemvanderdeure/Maxnet.jl",
+    is_pure_julia = false,
+    package_license = "MIT",
+    is_wrapper = false
+)
 
-MMI.load_path(::Type{<:MaxnetBinaryClassifier}) = "Maxnet.MaxnetBinaryClassifier"
+MMI.metadata_model(
+    MaxnetBinaryClassifier;
+    input_scitype = MMI.Table(MMI.Continuous, MMI.Finite),
+    target_scitype = AbstractVector{<:MMI.Finite{2}},
+    load_path = "Maxnet.MaxnetBinaryClassifier",
+    human_name = "Maxnet",
+    reports_feature_importances=false
+)
 
 function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y)
     # convert categorical to boolean
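Declaring the metadata through `MMI.metadata_pkg` and `MMI.metadata_model`, instead of hand-written trait methods, is what the new MLJ trait queries in test/runtests.jl rely on. Roughly, with MLJBase loaded:

    using MLJBase

    name(Maxnet.MaxnetBinaryClassifier)             # "MaxnetBinaryClassifier"
    human_name(Maxnet.MaxnetBinaryClassifier)       # "Maxnet"
    package_license(Maxnet.MaxnetBinaryClassifier)  # "MIT"
    input_scitype(Maxnet.MaxnetBinaryClassifier)    # a Table of Continuous and Finite columns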

src/predict.jl

Lines changed: 2 additions & 2 deletions

@@ -22,10 +22,10 @@ A `Vector` with the resulting predictions.
 using Maxnet
 p_a, env = Maxnet.bradypus();
 bradypus_model = maxnet(p_a, env; features = "lq")
-prediction = Maxnet.predict(bradypus_model, env)
+prediction = predict(bradypus_model, env)
 ```
 """
-function predict(m::MaxnetModel, x; link = CloglogLink(), clamp = false)
+function StatsAPI.predict(m::MaxnetModel, x; link = CloglogLink(), clamp = false)
     predictors = Tables.columntable(x)
     for k in keys(m.predictor_data)
         k in keys(predictors) || error("$k is not found in the predictors")
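Because the method now extends `StatsAPI.predict`, the exported `predict`, `Maxnet.predict`, and `StatsAPI.predict` all refer to the same function, and the keyword arguments are unchanged. A short sketch of the call variants; the `clamp` comment is inferred from the extrapolation test in test/runtests.jl:

    prediction = predict(bradypus_model, env)                          # CloglogLink() by default
    raw        = predict(bradypus_model, env; link = IdentityLink())   # untransformed model output
    clamped    = predict(bradypus_model, env; clamp = true)            # restrict predictors to the training range
    Maxnet.predict(bradypus_model, env) == prediction                  # same method, qualified call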

src/types.jl

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 struct MaxnetModel
-    path::Union{GLMNet.GLMNetPath, Lasso.LassoPath}
+    path::GLMNet.GLMNetPath
     features::Vector{<:AbstractFeatureClass}
     columns::Vector{ModelMatrixColumn}
     coefs::AbstractVector
@@ -11,7 +11,7 @@ struct MaxnetModel
 end
 
 function Base.show(io::IO, mime::MIME"text/plain", m::MaxnetModel)
-    vars_selected = mapreduce(Maxnet._var_keys, (x, y) -> unique(vcat(x, y)), selected_features(m))
+    vars_selected = mapreduce(Maxnet._var_keys, (x, y) -> unique(vcat(x, y)), selected_features(m); init = Symbol[])
 
     println(io, "Fit Maxnet model")
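The added `init = Symbol[]` matters when no features were selected at all (the new high-regularization test exercises exactly that case): `mapreduce` over an empty collection has no neutral element to return unless one is supplied. A minimal illustration with a generic mapping function rather than the package's `_var_keys`:

    mapreduce(x -> [x], vcat, Symbol[])                    # errors: reducing over an empty collection
    mapreduce(x -> [x], vcat, Symbol[]; init = Symbol[])   # returns Symbol[]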

test/runtests.jl

Lines changed: 51 additions & 13 deletions

@@ -1,4 +1,4 @@
-using Maxnet, Test, MLJBase, Statistics
+using Maxnet, Test, Statistics
 
 p_a, env = Maxnet.bradypus()
 env1 = map(e -> [e[1]], env) # just the first row
@@ -14,42 +14,81 @@ env1 = map(e -> [e[1]], env) # just the first row
     @test Maxnet.features_from_string("lqpt") == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), ProductFeature(), ThresholdFeature()]
 
     @test Maxnet.default_features(100) == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), HingeFeature(), ProductFeature()]
+
+    @test Maxnet.hinge(1:5, 3) == [
+    #   1:5   3:5  1:3  1:5
+        0.0   0.0  0.0  0.0
+        0.25  0.0  0.5  0.25
+        0.5   0.0  1.0  0.5
+        0.75  0.5  1.0  0.75
+        1.0   1.0  1.0  1.0
+    ]
+    @test size(Maxnet.hinge(1:200)) == (200, 98)
+
+    presence, predictors = Maxnet.addsamples([true, false, false], (a = [1,2,3], b = [1,2,3]))
+    @test presence == [true, false, false, false]
+    @test predictors == (a = [1,2,3,1], b = [1,2,3,1])
 end
 
 @testset "Maxnet" begin
     # some class combinations and keywords
-    m = Maxnet.maxnet(p_a, env; features = "lq");
-    Maxnet.maxnet(p_a, env; features = "lqp", regularization_multiplier = 2.);
-    Maxnet.maxnet(p_a, env; features = "lqh", regularization_multiplier = 5., nknots = 10);
-    Maxnet.maxnet(p_a, env; features = "lqph", weight_factor = 10.);
+    m = maxnet(p_a, env; features = "lq");
+    m2 = maxnet(p_a, env)
+    m3 = maxnet(p_a, env[(:cld6190_ann, :h_dem)])
+    m4 = maxnet(p_a, env[(:ecoreg,)], addsamplestobackground =false)
+    m5 = maxnet(p_a, env[(:cld6190_ann, :h_dem)]; features = "ht", n_knots = 3)
 
-    # test the result
+    # test the results
     @test m.entropy ≈ 6.114650341746531
     @test complexity(m) == 21
-
+    @test m2.features == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), HingeFeature(), ProductFeature()]
+    @test m3.features == [LinearFeature(), QuadraticFeature(), HingeFeature(), ProductFeature()]
+    @test m4.features == [CategoricalFeature()]
+    @test m5.features == [HingeFeature(), ThresholdFeature()]
+    @test length(m5.columns) == 14 # (n-1)*2 hinge columns and n threshold columns for each variable
+
     # predictions
-    prediction = Maxnet.predict(m, env)
+    prediction = predict(m, env)
     @test Statistics.mean(prediction[p_a]) > Statistics.mean(prediction[.~p_a])
     @test minimum(prediction) > 0.
     @test maximum(prediction) < 1.
     @test mean(prediction) ≈ 0.24375837576014572 atol=1e-4
 
     # check that clamping works
     # clamp shouldn't change anything in this case
-    @test prediction == Maxnet.predict(m, env; clamp = true)
+    @test prediction == predict(m, env; clamp = true)
 
     # predict with a crazy extrapolation
     env1_extrapolated = merge(env1, (;cld6190_ann = [100_000]))
     env1_max_cld = merge(env1, (;cld6190_ann = [maximum(env.cld6190_ann)]))
 
     # using clamp the prediction uses the highest cloud
-    @test Maxnet.predict(m, env1_extrapolated; link = IdentityLink(), clamp = true) ==
-        Maxnet.predict(m, env1_max_cld; link = IdentityLink())
+    @test predict(m, env1_extrapolated; link = IdentityLink(), clamp = true) ==
+        predict(m, env1_max_cld; link = IdentityLink())
+
+    # test that maxnet works if no features are selected
+    empty_model = maxnet(p_a, env; regularization_multiplier = 1000);
+    @test complexity(empty_model) == 0
+    @test Maxnet.selected_features(empty_model) == Symbol[]
+    @test length(unique(predict(empty_model, env))) == 1
 end
 
 @testset "MLJ" begin
+    using MLJBase
     mn = Maxnet.MaxnetBinaryClassifier
 
+    # Test model metadata
+    @test name(mn) == "MaxnetBinaryClassifier"
+    @test human_name(mn) == "Maxnet"
+    @test package_name(mn) == "Maxnet"
+    @test !supports_weights(mn)
+    @test !is_pure_julia(mn)
+    @test is_supervised(mn)
+    @test package_license(mn) == "MIT"
+    @test prediction_type(mn) == :probabilistic
+    @test input_scitype(mn) == Table{<:Union{AbstractVector{<:Continuous}, AbstractVector{<:Finite}}}
+    @test hyperparameters(mn) == (:features, :regularization_multiplier, :regularization_function, :weight_factor, :link, :clamp, :kw)
+
     # convert to continuous
     cont_keys = collect(key => Continuous for key in keys(env) if key !== :ecoreg)
     env_typed = MLJBase.coerce(env, cont_keys...)
@@ -64,14 +103,13 @@ end
     # make the equivalent model without mlj
     model = Maxnet.maxnet((p_a), env_typed; features = "lqph");
 
-
     # predict via MLJBase
     mljprediction = MLJBase.predict(mach2, env_typed)
     mlj_true_probability = pdf.(mljprediction, true)
 
     # test that this predicts the same as the equivalent model without mlj
 
-    @test all(Maxnet.predict(model, env_typed) .≈ mlj_true_probability)
+    @test all(predict(model, env_typed) .≈ mlj_true_probability)
 
     @test Statistics.mean(mlj_true_probability[p_a]) > Statistics.mean(mlj_true_probability[.~p_a])
     @test minimum(mlj_true_probability) > 0.
