Skip to content

Commit 1742c96

Browse files
fully implement MLJ (#22)
* add mlj docstring * test with MLJTestInterface * throw a helpful error if input data only has one class * mljtestinterface is not a dep (oops) * move allequal error to main function * fix allequal error * fix tests * add MLJBase as docs dep * fix mlj doctest * attempt fix of multiclass printing * use @example instead of jldoctest * test for no failures in mlj interface test
1 parent 73daa4d commit 1742c96

File tree

5 files changed

+41
-27
lines changed

5 files changed

+41
-27
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ julia = "1.9"
3636
[extras]
3737
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
3838
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
39+
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
3940
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4041

4142
[targets]
42-
test = ["DelimitedFiles", "MLJBase", "Test"]
43+
test = ["DelimitedFiles", "MLJBase", "MLJTestInterface", "Test"]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
34
Maxnet = "81f79f80-22f2-4e41-ab86-00c11cf0f26f"

src/maxnet_function.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ function maxnet(
4949
n_knots::Int = 50,
5050
kw...)
5151

52+
if allequal(presences)
53+
pa = first(presences) ? "presences" : "absences"
54+
throw(ArgumentError("All data points are $pa. Maxnet will only work with at least some presences and some absences."))
55+
end
56+
5257
_maxnet(
5358
presences,
5459
predictors,

src/mlj_interface.jl

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,6 @@ function MaxnetBinaryClassifier(;
2424
)
2525
end
2626

27-
"""
28-
MaxnetBinaryClassifier
29-
30-
A model type for fitting a maxnet model using `MLJ`.
31-
32-
Use `MaxnetBinaryClassifier()` to create an instance with default parameters, or use keyword arguments to specify parameters.
33-
34-
The keywords `link`, and `clamp` are passed to [`Maxnet.predict`](@ref), while all other keywords are passed to [`maxnet`](@ref).
35-
See the documentation of these functions for the meaning of these parameters and their defaults.
36-
37-
# Example
38-
```jldoctest
39-
using Maxnet, MLJBase
40-
p_a, env = Maxnet.bradypus()
41-
42-
mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a))
43-
fit!(mach)
44-
yhat = MLJBase.predict(mach, env)
45-
# output
46-
```
47-
48-
"""
49-
MaxnetBinaryClassifier
50-
5127
MMI.metadata_pkg(
5228
MaxnetBinaryClassifier;
5329
name = "Maxnet",
@@ -67,6 +43,26 @@ MMI.metadata_model(
6743
reports_feature_importances=false
6844
)
6945

46+
"""
47+
$(MMI.doc_header(MaxnetBinaryClassifier))
48+
49+
The keywords `link`, and `clamp` are passed to [`predict`](@ref), while all other keywords are passed to [`maxnet`](@ref).
50+
See the documentation of these functions for the meaning of these parameters and their defaults.
51+
52+
# Example
53+
```@example
54+
using MLJBase
55+
p_a, env = Maxnet.bradypus()
56+
57+
mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a), scitype_check_level = 0)
58+
fit!(mach, verbosity = 0)
59+
yhat = MLJBase.predict(mach, env)
60+
61+
```
62+
63+
"""
64+
MaxnetBinaryClassifier
65+
7066
function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y)
7167
# convert categorical to boolean
7268
y_boolean = Bool.(MMI.int(y) .- 1)

test/runtests.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using Maxnet, Test, Statistics, CategoricalArrays
1+
using Maxnet, Statistics, CategoricalArrays, MLJTestInterface
2+
using Test
23

4+
# read in Bradypus data
35
p_a, env = Maxnet.bradypus()
46
# Make the levels in ecoreg string to make sure that that works
57
env = merge(env, (; ecoreg = recode(env.ecoreg, (l => string(l) for l in levels(env.ecoreg))...)))
@@ -82,9 +84,18 @@ end
8284
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)
8385
@test m_w.entropy > m.entropy
8486
end
85-
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)
8687

8788
@testset "MLJ" begin
89+
data = MLJTestInterface.make_binary()
90+
failures, summary = MLJTestInterface.test(
91+
[MaxnetBinaryClassifier],
92+
data...;
93+
mod=@__MODULE__,
94+
verbosity=0, # bump to debug
95+
throw=false, # set to true to debug
96+
)
97+
@test isempty(failures)
98+
8899
using MLJBase
89100
mn = Maxnet.MaxnetBinaryClassifier
90101

0 commit comments

Comments
 (0)