Skip to content

Commit 6f1b3fe

Browse files
Merge pull request #109 from tidymodels/more-objectives-lightgbm
2 parents 0e7afbc + 2c83f2b commit 6f1b3fe

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# bonsai (development version)
22

3+
* Automatic handling of `num_classes` argument when specifying a multiclass classification objective for the lightgbm engine (#109).
4+
35
* Increased the minimum R version to R 4.1.
46

57
* Fixed bug where `num_threads` argument were ignored for lightgbm engine (#105).

R/lightgbm.R

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ train_lightgbm <- function(
101101
)
102102

103103
args <- process_bagging(args)
104-
args <- process_objective_function(args, x, y)
104+
args <- process_objective_function(args, y)
105105

106106
args <- sort_args(args)
107107

@@ -175,24 +175,40 @@ process_mtry <- function(
175175
feature_fraction_bynode
176176
}
177177

178-
process_objective_function <- function(args, x, y) {
178+
# https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters
179+
process_objective_function <- function(args, y) {
180+
lvl <- levels(y)
181+
lvls <- length(lvl)
179182
# set the "objective" param argument, clear it out from main args
180183
if (!any(names(args) %in% c("objective"))) {
181184
if (is.numeric(y)) {
182185
args$objective <- "regression"
183186
} else {
184-
lvl <- levels(y)
185-
lvls <- length(lvl)
186187
if (lvls == 2) {
187-
args$num_class <- 1
188188
args$objective <- "binary"
189189
} else {
190-
args$num_class <- lvls
191190
args$objective <- "multiclass"
192191
}
193192
}
194193
}
195194

195+
if (args$objective == "binary" && is.null(args$num_class)) {
196+
args$num_class <- 1L
197+
}
198+
199+
multiclass_obj <- c(
200+
"multiclass",
201+
"softmax",
202+
"multiclassova",
203+
"multiclass_ova",
204+
"ova",
205+
"ovr"
206+
)
207+
208+
if (args$objective %in% multiclass_obj && is.null(args$num_class)) {
209+
args$num_class <- lvls
210+
}
211+
196212
args
197213
}
198214

tests/testthat/test-lightgbm.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,30 @@ test_that("bonsai correctly determines objective when label is a factor", {
348348
expect_equal(bst$params$num_class, 3)
349349
})
350350

351+
test_that("bonsai correctly determines num_classes when objective is set", {
352+
skip_if_not_installed("lightgbm")
353+
skip_if_not_installed("modeldata")
354+
355+
suppressPackageStartupMessages({
356+
library(lightgbm)
357+
library(dplyr)
358+
})
359+
360+
data("penguins", package = "modeldata")
361+
penguins <- penguins[complete.cases(penguins), ]
362+
363+
expect_no_error({
364+
bst <- train_lightgbm(
365+
x = penguins[, c("bill_length_mm", "bill_depth_mm")],
366+
y = penguins[["species"]],
367+
num_iterations = 5,
368+
verbose = -1L,
369+
objective = "multiclassova"
370+
)
371+
})
372+
expect_identical(bst$params$objective, "multiclassova")
373+
expect_identical(bst$params$num_class, 3L)
374+
})
351375

352376
test_that("bonsai handles mtry vs mtry_prop gracefully", {
353377
skip_if_not_installed("modeldata")

0 commit comments

Comments
 (0)