Skip to content

Commit 0a453ce

Browse files
Make work with all versions of xgboost (#119)
1 parent 007902c commit 0a453ce

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# orbital (development version)
22

3+
* Make work with new versions of xgboost. (#119)
4+
35
# orbital 0.4.0
46

57
* Added support for tailor package and its integration into workflows. The following adjustments have gained `orbital()` support. (#103)

R/model-xgboost.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ orbital.xgb.Booster <- function(
1010
type <- default_type(type)
1111

1212
if (mode == "classification") {
13-
objective <- x$params$objective
13+
objective <- x$params$objective %||% attr(x, "params")$objective
1414
objective <- rlang::arg_match0(
1515
objective,
1616
c("multi:softprob", "binary:logistic")
@@ -32,7 +32,10 @@ orbital.xgb.Booster <- function(
3232
xgboost_multisoft <- function(x, type, lvl) {
3333
trees <- tidypredict::.extract_xgb_trees(x)
3434

35-
trees_split <- split(trees, rep(seq_along(lvl), x$niter))
35+
trees_split <- split(
36+
trees,
37+
rep(seq_along(lvl), x$niter %||% nrow(attr(x, "evaluation_log")))
38+
)
3639
trees_split <- lapply(trees_split, collapse_stumps)
3740
trees_split <- vapply(trees_split, paste, character(1), collapse = " + ")
3841

tests/testthat/test-model-xgboost.R

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
1+
test_that("boost_tree(), objective = reg:squarederror, works with type = numeric", {
2+
skip_if_not_installed("parsnip")
3+
skip_if_not_installed("tidypredict")
4+
skip_if_not_installed("xgboost")
5+
6+
bt_spec <- parsnip::boost_tree(mode = "regression", engine = "xgboost")
7+
8+
bt_fit <- parsnip::fit(bt_spec, mpg ~ disp + vs + hp, mtcars)
9+
10+
orb_obj <- orbital(bt_fit)
11+
12+
# to avoid exact split values
13+
mtcars <- mtcars + 0.1
14+
15+
preds <- predict(orb_obj, mtcars)
16+
exps <- predict(bt_fit, mtcars)
17+
18+
expect_named(preds, ".pred")
19+
expect_type(preds$.pred, "double")
20+
21+
exps <- as.data.frame(exps)
22+
23+
rownames(preds) <- NULL
24+
rownames(exps) <- NULL
25+
26+
expect_equal(
27+
preds,
28+
exps,
29+
tolerance = 0.0000001
30+
)
31+
})
32+
133
test_that("boost_tree(), objective = binary:logistic, works with type = class", {
234
skip_if_not_installed("parsnip")
335
skip_if_not_installed("tidypredict")
@@ -11,6 +43,9 @@ test_that("boost_tree(), objective = binary:logistic, works with type = class",
1143

1244
orb_obj <- orbital(bt_fit, type = "class")
1345

46+
# to avoid exact split values
47+
mtcars[, -8] <- mtcars[, -8] + 0.1
48+
1449
preds <- predict(orb_obj, mtcars)
1550
exps <- predict(bt_fit, mtcars)
1651

@@ -23,7 +58,7 @@ test_that("boost_tree(), objective = binary:logistic, works with type = class",
2358
)
2459
})
2560

26-
test_that("boost_tree(), objective = binary:logistic, works with type = class", {
61+
test_that("boost_tree(), objective = multi:softprob, works with type = class", {
2762
skip_if_not_installed("parsnip")
2863
skip_if_not_installed("tidypredict")
2964
skip_if_not_installed("xgboost")
@@ -34,6 +69,9 @@ test_that("boost_tree(), objective = binary:logistic, works with type = class",
3469

3570
orb_obj <- orbital(bt_fit, type = "class")
3671

72+
# to avoid exact split values
73+
iris[, -5] <- iris[, -5] + 0.05
74+
3775
preds <- predict(orb_obj, iris)
3876
exps <- predict(bt_fit, iris)
3977

@@ -59,6 +97,9 @@ test_that("boost_tree(), objective = binary:logistic, works with type = prob", {
5997

6098
orb_obj <- orbital(bt_fit, type = "prob")
6199

100+
# to avoid exact split values
101+
mtcars[, -8] <- mtcars[, -8] + 0.1
102+
62103
preds <- predict(orb_obj, mtcars)
63104
exps <- predict(bt_fit, mtcars, type = "prob")
64105

@@ -78,7 +119,7 @@ test_that("boost_tree(), objective = binary:logistic, works with type = prob", {
78119
)
79120
})
80121

81-
test_that("boost_tree(), objective = binary:logistic, works with type = prob", {
122+
test_that("boost_tree(), objective = multi:softprob, works with type = prob", {
82123
skip_if_not_installed("parsnip")
83124
skip_if_not_installed("tidypredict")
84125
skip_if_not_installed("xgboost")
@@ -89,6 +130,9 @@ test_that("boost_tree(), objective = binary:logistic, works with type = prob", {
89130

90131
orb_obj <- orbital(bt_fit, type = "prob")
91132

133+
# to avoid exact split values
134+
iris[, -5] <- iris[, -5] + 0.05
135+
92136
preds <- predict(orb_obj, iris)
93137
exps <- predict(bt_fit, iris, type = "prob")
94138

@@ -122,6 +166,9 @@ test_that("boost_tree(), objective = binary:logistic, works with type = c(class,
122166

123167
orb_obj <- orbital(bt_fit, type = c("class", "prob"))
124168

169+
# to avoid exact split values
170+
mtcars[, -8] <- mtcars[, -8] + 0.1
171+
125172
preds <- predict(orb_obj, mtcars)
126173
exps <- dplyr::bind_cols(
127174
predict(bt_fit, mtcars, type = c("class")),
@@ -146,7 +193,7 @@ test_that("boost_tree(), objective = binary:logistic, works with type = c(class,
146193
)
147194
})
148195

149-
test_that("boost_tree(), objective = binary:logistic, works with type = c(class, prob)", {
196+
test_that("boost_tree(), objective = multi:softprob, works with type = c(class, prob)", {
150197
skip_if_not_installed("parsnip")
151198
skip_if_not_installed("tidypredict")
152199
skip_if_not_installed("xgboost")
@@ -157,6 +204,9 @@ test_that("boost_tree(), objective = binary:logistic, works with type = c(class,
157204

158205
orb_obj <- orbital(bt_fit, type = c("class", "prob"))
159206

207+
# to avoid exact split values
208+
iris[, -5] <- iris[, -5] + 0.05
209+
160210
preds <- predict(orb_obj, iris)
161211
exps <- dplyr::bind_cols(
162212
predict(bt_fit, iris, type = c("class")),

0 commit comments

Comments
 (0)