|
1 | 1 | using Maxnet, Test, MLJBase, Statistics |
2 | 2 |
|
# Load the bradypus example data shipped with Maxnet:
# `p_a` is the presence/absence vector, `env` the environmental predictors.
p_a, env = Maxnet.bradypus()
# A one-observation version of the environment: keep only the first row
# of every predictor (used below for single-point prediction tests).
env1 = map(e -> [first(e)], env)

|
@testset "utils" begin
    # an unrecognised feature letter should throw
    @test_throws ErrorException Maxnet.features_from_string("a")

    # each feature string should expand to the expected feature classes
    # (note that "l" implies categorical features as well as linear ones)
    string_to_features = [
        "l"    => [LinearFeature(), CategoricalFeature()],
        "q"    => [QuadraticFeature()],
        "lq"   => [LinearFeature(), CategoricalFeature(), QuadraticFeature()],
        "lqp"  => [LinearFeature(), CategoricalFeature(), QuadraticFeature(), ProductFeature()],
        "lqph" => [LinearFeature(), CategoricalFeature(), QuadraticFeature(), ProductFeature(), HingeFeature()],
        "lqpt" => [LinearFeature(), CategoricalFeature(), QuadraticFeature(), ProductFeature(), ThresholdFeature()],
    ]
    for (str, feats) in string_to_features
        @test Maxnet.features_from_string(str) == feats
    end

    # with many presence points the default feature set includes hinge and product features
    @test Maxnet.default_features(100) == [LinearFeature(), CategoricalFeature(), QuadraticFeature(), HingeFeature(), ProductFeature()]
end
11 | 18 |
|
12 | 19 | @testset "Maxnet" begin |
|
18 | 25 | @test all(isapprox.(model_glmnet.coefs, model_lasso.coefs; rtol = 0.1, atol = 0.1)) |
19 | 26 | @test Statistics.cor(model_glmnet.coefs, model_lasso.coefs) > 0.99 |
20 | 27 |
|
21 | | - # select classes automatically |
22 | | - Maxnet.maxnet(p_a, env; backend = LassoBackend()); |
| 28 | + # test the result |
| 29 | + @test model_glmnet.entropy ≈ 6.114650341746531 |
| 30 | + @test complexity(model_glmnet) == 21 |
23 | 31 |
|
24 | | - # some class combinations |
| 32 | + # some class combinations and keywords |
25 | 33 | Maxnet.maxnet(p_a, env; features = "lq", backend = LassoBackend()); |
26 | 34 | Maxnet.maxnet(p_a, env; features = "lqp", regularization_multiplier = 2., backend = LassoBackend()); |
27 | | - Maxnet.maxnet(p_a, env; features = "lqh", regularization_multiplier = 5., backend = LassoBackend()); |
28 | | - Maxnet.maxnet(p_a, env; features = "lqph", backend = LassoBackend()); |
29 | | - Maxnet.maxnet(p_a, env; features = "lqpt", backend = LassoBackend()); |
| 35 | + Maxnet.maxnet(p_a, env; features = "lqh", regularization_multiplier = 5., nknots = 10, backend = LassoBackend()); |
| 36 | + Maxnet.maxnet(p_a, env; features = "lqph", weight_factor = 10., backend = LassoBackend()); |
30 | 37 |
|
31 | 38 | # predictions |
32 | 39 | prediction = Maxnet.predict(model_lasso, env) |
33 | 40 | @test Statistics.mean(prediction[p_a]) > Statistics.mean(prediction[.~p_a]) |
34 | 41 | @test minimum(prediction) > 0. |
35 | 42 | @test maximum(prediction) < 1. |
| 43 | + @test mean(prediction) ≈ 0.243406167194403 atol=1e-4 |
36 | 44 |
|
| 45 | + # check that clamping works |
37 | 46 | # clamp shouldn't change anything in this case |
38 | | - @test all(prediction .== Maxnet.predict(model_lasso, env; clamp = true)) |
| 47 | + @test prediction == Maxnet.predict(model_lasso, env; clamp = true) |
39 | 48 |
|
40 | 49 | # predict with a crazy extrapolation |
41 | 50 | env1_extrapolated = merge(env1, (;cld6190_ann = [100_000])) |
42 | | - # without clamp the prediction is crazy |
43 | | - @test abs(Maxnet.predict(model_lasso, env1_extrapolated; link = IdentityLink())[1]) > 100_000. |
44 | | - # without clamp the prediction is reasonable |
45 | | - @test abs(Maxnet.predict(model_lasso, env1_extrapolated; link = IdentityLink(), clamp = true)[1]) < 5. |
| 51 | + env1_max_cld = merge(env1, (;cld6190_ann = [maximum(env.cld6190_ann)])) |
| 52 | + |
| 53 | + # using clamp the prediction uses the highest cloud |
| 54 | + @test Maxnet.predict(model_lasso, env1_extrapolated; link = IdentityLink(), clamp = true) == |
| 55 | + Maxnet.predict(model_lasso, env1_max_cld; link = IdentityLink()) |
46 | 56 | end |
47 | 57 |
|
48 | 58 | @testset "MLJ" begin |
|
59 | 69 | mach2 = machine(mn(features = "lqph", backend = GLMNetBackend()), env_typed, categorical(p_a)) |
60 | 70 | fit!(mach2) |
61 | 71 |
|
| 72 | + # make the equivalent model without mlj |
| 73 | + model = Maxnet.maxnet((p_a), env_typed; features = "lqph", backend = GLMNetBackend()); |
| 74 | + |
| 75 | + |
62 | 76 | # predict via MLJBase |
63 | 77 | mljprediction = MLJBase.predict(mach2, env_typed) |
64 | 78 | mlj_true_probability = pdf.(mljprediction, true) |
65 | 79 |
|
| 80 | + # test that this predicts the same as the equivalent model without mlj |
| 81 | + @test all(Maxnet.predict(model, env_typed) .≈ mlj_true_probability) |
| 82 | + |
66 | 83 | @test Statistics.mean(mlj_true_probability[p_a]) > Statistics.mean(mlj_true_probability[.~p_a]) |
67 | 84 | @test minimum(mlj_true_probability) > 0. |
68 | 85 | @test maximum(mlj_true_probability) < 1. |
|
0 commit comments