Skip to content

Commit c2a1a9e

Browse files
Merge pull request #111 from tidymodels/catboost
2 parents 67a9562 + 89c9d7f commit c2a1a9e

File tree

13 files changed

+940
-2
lines changed

13 files changed

+940
-2
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ jobs:
3939
env:
4040
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
4141
R_KEEP_PKG_SOURCE: yes
42+
GHA_OS: ${{ matrix.config.os }}
4243

4344
steps:
4445
- uses: actions/checkout@v4
@@ -56,6 +57,25 @@ jobs:
5657
extra-packages: any::rcmdcheck
5758
needs: check
5859

60+
- name: Install catboost
61+
run: |
62+
install.packages('remotes')
63+
os <- Sys.getenv("GHA_OS")
64+
if (os == "macos-latest") {
65+
name <- "Darwin"
66+
} else if (os == "windows-latest") {
67+
name <- "Windows"
68+
} else {
69+
name <- "Linux"
70+
}
71+
url <- paste0(
72+
'https://github.com/catboost/catboost/releases/download/v1.2.8/catboost-R-',
73+
name,
74+
'-1.2.8.tgz'
75+
)
76+
remotes::install_url(url, INSTALL_opts = c("--no-multiarch", "--no-test-load", "--no-staged-install"))
77+
shell: Rscript {0}
78+
5979
- uses: r-lib/actions/check-r-package@v2
6080
with:
6181
upload-snapshots: true

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22

33
S3method(multi_predict,"_lgb.Booster")
44
export("%>%")
5+
export(predict_catboost_classification_class)
6+
export(predict_catboost_classification_prob)
7+
export(predict_catboost_classification_raw)
8+
export(predict_catboost_regression_numeric)
59
export(predict_lightgbm_classification_class)
610
export(predict_lightgbm_classification_prob)
711
export(predict_lightgbm_classification_raw)
812
export(predict_lightgbm_regression_numeric)
13+
export(train_catboost)
914
export(train_lightgbm)
1015
import(rlang)
1116
importFrom(dials,min_n)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
* Fixed bug where the `num_threads` argument was ignored for the lightgbm engine (#105).
1010

11+
* Added catboost engine to `boost_tree()` (#70).
12+
1113
# bonsai 0.3.2
1214

1315
* Resolves a test failure ahead of an upcoming parsnip release (#95).

R/catboost.R

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
#' Boosted trees with catboost
#'
#' `train_catboost()` is a wrapper around `catboost` tree-based model fitting
#' in which every model argument is exposed as a main function argument.
#'
#' This is an internal function, not meant to be directly called by the user.
#'
#' @param x A data frame of predictors.
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
#' @param weights A numeric vector of sample weights, defaults to `NULL`.
#' @param iterations The maximum number of trees that can be built when solving
#' machine learning problems. Defaults to 1000.
#' @param learning_rate A positive numeric value for the learning rate. Defaults
#' to 0.03.
#' @param depth An integer for the depth of the trees. Defaults to 6.
#' @param l2_leaf_reg A numeric value for the L2 regularization coefficient.
#' Used for leaf value calculation. Defaults to 3.
#' @param random_strength The amount of randomness to use for scoring splits
#' when the tree structure is selected. Use this parameter to avoid overfitting
#' the model. Defaults to 1.
#' @param bagging_temperature A numeric value that controls the intensity of
#' Bayesian bagging. The higher the temperature, the more aggressive the
#' bagging. Defaults to 1.
#' @param rsm A numeric value between 0 and 1, random subspace method. The
#' percentage of features to use at each iteration of building trees. At each
#' iteration, features are selected over again at random. Defaults to 1.
#' @param quiet A logical; should logging by [catboost::catboost.train()] be
#' muted? Defaults to `TRUE`.
#' @param ... Other options to pass to [catboost::catboost.train()]. They are
#' collected, together with the model arguments above, into the `params` list
#' handed to the training function. A list supplied as `params` is dropped
#' with a warning; pass its elements directly instead.
#'
#' @source \url{https://catboost.ai/docs/en/references/training-parameters/}.
#'
#' @return A fitted `catboost.Model` object.
#' @keywords internal
#' @export
train_catboost <- function(
  x,
  y,
  weights = NULL,
  iterations = 1000,
  learning_rate = 0.03,
  depth = 6,
  l2_leaf_reg = 3,
  random_strength = 1,
  bagging_temperature = 1,
  rsm = 1,
  quiet = TRUE,
  ...
) {
  force(x)
  force(y)

  # Input-validation errors are reported against a `fit()` call.
  error_call <- call2("fit")

  check_number_whole(iterations, call = error_call)
  check_number_decimal(learning_rate, call = error_call)
  check_number_whole(depth, call = error_call)
  check_number_decimal(l2_leaf_reg, call = error_call)
  check_number_decimal(random_strength, call = error_call)
  check_number_decimal(bagging_temperature, call = error_call)
  check_number_decimal(rsm, call = error_call)
  check_bool(quiet, call = error_call)

  # Main arguments and anything in `...` end up in one training-parameter
  # list; a default loss function is filled in when none was supplied.
  train_params <- process_loss_function(
    list(
      iterations = iterations,
      learning_rate = learning_rate,
      depth = depth,
      l2_leaf_reg = l2_leaf_reg,
      random_strength = random_strength,
      bagging_temperature = bagging_temperature,
      rsm = rsm,
      ...
    ),
    y
  )

  # A nested `params` list is not forwarded: warn, then drop it.
  nested_params <- train_params$params
  if (!is.null(nested_params) && is.list(nested_params)) {
    cli::cli_warn(c(
      "Arguments passed in through {.arg params} as a list will be ignored.",
      "Instead pass the arguments directly to the {.code ...}."
    ))
    train_params$params <- NULL
  }

  # catboost is reached through constructed calls rather than `catboost::`.
  pool_call <- rlang::call2(
    "catboost.load_pool",
    data = x,
    label = y,
    weight = weights,
    .ns = "catboost"
  )
  learn_pool <- rlang::eval_tidy(pool_call, env = rlang::current_env())

  train_call <- rlang::call2(
    "catboost.train",
    learn_pool = learn_pool,
    params = train_params,
    .ns = "catboost"
  )

  if (!quiet) {
    fitted <- rlang::eval_tidy(train_call, env = rlang::current_env())
  } else {
    # Capture (and discard) whatever training prints to R's standard output.
    discarded_log <- utils::capture.output(
      fitted <- rlang::eval_tidy(train_call, env = rlang::current_env())
    )
  }

  fitted
}
113+
#' Internal functions
#'
#' Not intended for direct use.
#'
#' @param object An object whose `fit` element is the model handed to
#' `catboost::catboost.predict()`. The classification helpers also read its
#' `lvl` element to label their output.
#' @param new_data Data to predict on; passed as `data` to
#' `catboost::catboost.load_pool()`.
#' @param ... Not used.
#'
#' @keywords internal
#' @export
#' @rdname catboost_helpers
predict_catboost_regression_numeric <- function(object, new_data, ...) {
  # Wrap the new data in a catboost pool before predicting.
  pool_call <- rlang::call2(
    "catboost.load_pool",
    data = new_data,
    .ns = "catboost"
  )
  new_pool <- rlang::eval_tidy(pool_call)

  # No `prediction_type` is given, so catboost's default applies.
  predict_call <- rlang::call2(
    "catboost.predict",
    model = object$fit,
    pool = new_pool,
    .ns = "catboost"
  )
  preds <- rlang::eval_tidy(predict_call)

  preds
}
136+
#' @keywords internal
#' @export
#' @rdname catboost_helpers
predict_catboost_classification_class <- function(object, new_data, ...) {
  # Wrap the new data in a catboost pool before predicting.
  pool_call <- rlang::call2(
    "catboost.load_pool",
    data = new_data,
    .ns = "catboost"
  )
  new_pool <- rlang::eval_tidy(pool_call)

  predict_call <- rlang::call2(
    "catboost.predict",
    model = object$fit,
    pool = new_pool,
    prediction_type = "Class",
    .ns = "catboost"
  )
  class_index <- rlang::eval_tidy(predict_call)

  # Shift by one to index into the stored levels; presumably catboost
  # reports classes as zero-based indices (confirm against catboost docs).
  object$lvl[class_index + 1]
}
157+
#' @keywords internal
#' @export
#' @rdname catboost_helpers
predict_catboost_classification_prob <- function(object, new_data, ...) {
  # Wrap the new data in a catboost pool before predicting.
  pool_call <- rlang::call2(
    "catboost.load_pool",
    data = new_data,
    .ns = "catboost"
  )
  new_pool <- rlang::eval_tidy(pool_call)

  predict_call <- rlang::call2(
    "catboost.predict",
    model = object$fit,
    pool = new_pool,
    prediction_type = "Probability",
    .ns = "catboost"
  )
  probs <- rlang::eval_tidy(predict_call)

  # A plain vector is expanded to two columns: the values themselves become
  # the second column and their complement the first.
  if (is.vector(probs)) {
    probs <- tibble::tibble(p1 = 1 - probs, p2 = probs)
  }

  # Columns are named after the outcome levels, in order.
  colnames(probs) <- object$lvl

  tibble::as_tibble(probs)
}
184+
#' @keywords internal
#' @export
#' @rdname catboost_helpers
predict_catboost_classification_raw <- function(object, new_data, ...) {
  # Wrap the new data in a catboost pool before predicting.
  pool_call <- rlang::call2(
    "catboost.load_pool",
    data = new_data,
    .ns = "catboost"
  )
  new_pool <- rlang::eval_tidy(pool_call)

  # No `prediction_type` is given, so catboost's default output is returned
  # as-is.
  predict_call <- rlang::call2(
    "catboost.predict",
    model = object$fit,
    pool = new_pool,
    .ns = "catboost"
  )
  raw_preds <- rlang::eval_tidy(predict_call)

  raw_preds
}
203+
# https://catboost.ai/docs/en/concepts/loss-functions
#
# Fill in a default `loss_function` entry in `args` based on the outcome `y`,
# unless the caller already supplied one:
#   * numeric outcome        -> "RMSE"
#   * outcome with 2 levels  -> "Logloss"
#   * anything else          -> "MultiClass"
# Returns `args`, with `loss_function` appended when it was missing.
process_loss_function <- function(args, y) {
  n_levels <- nlevels(y)

  # A user-supplied loss function always wins; leave `args` untouched.
  if ("loss_function" %in% names(args)) {
    return(args)
  }

  args$loss_function <- if (is.numeric(y)) {
    "RMSE"
  } else if (n_levels == 2) {
    "Logloss"
  } else {
    "MultiClass"
  }

  args
}

0 commit comments

Comments
 (0)