Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# parsnip (development version)

* Updates to some boosting tuning parameter information:
- lightgbm and catboost have smaller default ranges for the learning rate: -3 to -1 / 2 in log10 units.
- lightgbm, xgboost, catboost, and C5.0 have smaller default ranges for the sampling proportion: 0.5 to 1.0.
- catboost engine arguments were added for `max_leaves` and `l2_leaf_reg`.

* Enable generalized random forest (`grf`) models for classification, regression, and quantile regression modes. (#1288)

* `surv_reg()` is now defunct and will error if called. Please use `survival_reg()` instead (#1206).
Expand Down
29 changes: 26 additions & 3 deletions R/tunable.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ lightgbm_engine_args <-
component_id = "engine"
)

catboost_engine_args <-
tibble::tibble(
name = c(
"max_leaves",
"l2_leaf_reg"
),
call_info = list(
list(pkg = "dials", fun = "num_leaves"),
list(pkg = "dials", fun = "penalty", range = c(-4, 1))
),
source = "model_spec",
component = "boost_tree",
component_id = "engine"
)

ranger_engine_args <-
tibble::tibble(
name = c(
Expand Down Expand Up @@ -345,19 +360,27 @@ tunable.boost_tree <- function(x, ...) {
if (x$engine == "xgboost") {
res <- add_engine_parameters(res, xgboost_engine_args)
res$call_info[res$name == "sample_size"] <-
list(list(pkg = "dials", fun = "sample_prop"))
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
res$call_info[res$name == "learn_rate"] <-
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1 / 2)))
} else if (x$engine == "C5.0") {
res <- add_engine_parameters(res, c5_boost_engine_args)
res$call_info[res$name == "trees"] <-
list(list(pkg = "dials", fun = "trees", range = c(1, 100)))
res$call_info[res$name == "sample_size"] <-
list(list(pkg = "dials", fun = "sample_prop"))
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
} else if (x$engine == "lightgbm") {
res <- add_engine_parameters(res, lightgbm_engine_args)
res$call_info[res$name == "sample_size"] <-
list(list(pkg = "dials", fun = "sample_prop"))
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
res$call_info[res$name == "learn_rate"] <-
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1 / 2)))
} else if (x$engine == "catboost") {
res <- add_engine_parameters(res, catboost_engine_args)
res$call_info[res$name == "learn_rate"] <-
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1 / 2)))
res$call_info[res$name == "sample_size"] <-
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
}
res
}
Expand Down