Skip to content

Commit 143614c

Browse files
Fixed bug where coef() didn't would error if used on a brulee_logistic_reg() that was trained with a recipe (#67)
* extract rows instead of columns in coef.brulee_logistic_reg() * add news * add test --------- Co-authored-by: Max Kuhn <[email protected]>
1 parent 3424cf1 commit 143614c

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# brulee (development version)
22

3+
* Fixed bug where `coef()` didn't would error if used on a `brulee_logistic_reg()` that was trained with a recipe. (#66)
4+
35
* Fixed a bug where SGD always being used as the optimizer (#61).
46

7+
58
# brulee 0.2.0
69

710
* Several learning rate schedulers were added to the modeling functions (#12).

R/coef.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ brulee_coefs <- function(object, epoch = NULL, ...) {
5858
#' @export
5959
coef.brulee_logistic_reg <- function(object, epoch = NULL, ...) {
6060
network_params <- brulee_coefs(object, epoch)
61-
slopes <- network_params$fc1.weight[,2] - network_params$fc1.weight[,1]
61+
slopes <- network_params$fc1.weight[2, ] - network_params$fc1.weight[1, ]
6262
int <- network_params$fc1.bias[2] - network_params$fc1.bias[1]
6363
param <- c(int, slopes)
6464
names(param) <- c("(Intercept)", object$dims$features)

tests/testthat/test-logistic_reg-fit.R

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,35 @@ test_that("basic logistic regression SGD", {
112112
expect_true(bin_brier_sgd$.estimate < (1 - 1/num_class)^2)
113113
})
114114

115+
test_that("coef works when recipes are used", {
116+
skip_if_not(torch::torch_is_installed())
117+
skip_if_not_installed("modeldata")
118+
skip_if_not_installed("recipes")
119+
skip_if(packageVersion("rlang") < "1.0.0")
120+
skip_on_os(c("windows", "linux", "solaris"))
121+
122+
data("lending_club", package = "modeldata")
123+
lending_club <- head(lending_club, 1000)
124+
125+
rec <-
126+
recipes::recipe(Class ~ revol_util + open_il_24m + emp_length,
127+
data = lending_club) %>%
128+
recipes::step_dummy(emp_length, one_hot = TRUE) %>%
129+
recipes::step_normalize(recipes::all_predictors())
130+
131+
fit_rec <- brulee_logistic_reg(rec, lending_club, epochs = 10L)
132+
133+
coefs <- coef(fit_rec)
134+
expect_true(all(is.numeric(coefs)))
135+
expect_identical(
136+
names(coefs),
137+
c(
138+
"(Intercept)", "revol_util", "open_il_24m",
139+
paste0("emp_length_", levels(lending_club$emp_length))
140+
)
141+
)
142+
})
143+
115144

116145
# ------------------------------------------------------------------------------
117146

0 commit comments

Comments
 (0)