Skip to content

Commit dbe666f

Browse files
committed
Merge branch 'main' into sparse-glmnet-predict
2 parents ba56e2a + 53263e9 commit dbe666f

File tree

216 files changed

+1909
-1126
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

216 files changed

+1909
-1126
lines changed

DESCRIPTION

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.2.1.9002
3+
Version: 1.2.1.9004
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
@@ -25,23 +25,23 @@ Imports:
2525
ggplot2,
2626
globals,
2727
glue,
28-
hardhat (>= 1.4.0),
28+
hardhat (>= 1.4.0.9003),
2929
lifecycle,
3030
magrittr,
3131
pillar,
3232
prettyunits,
3333
purrr (>= 1.0.0),
3434
rlang (>= 1.1.0),
35-
sparsevctrs (>= 0.1.0.9000),
35+
sparsevctrs (>= 0.1.0.9002),
3636
stats,
3737
tibble (>= 2.1.1),
3838
tidyr (>= 1.3.0),
3939
utils,
4040
vctrs (>= 0.6.0),
4141
withr
4242
Suggests:
43-
C50,
4443
bench,
44+
C50,
4545
covr,
4646
dials (>= 1.1.0),
4747
earth,
@@ -69,16 +69,17 @@ Suggests:
6969
xgboost (>= 1.5.0.1)
7070
VignetteBuilder:
7171
knitr
72+
Remotes:
73+
r-lib/sparsevctrs,
74+
tidymodels/hardhat
7275
ByteCompile: true
7376
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
74-
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
75-
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
77+
LiblineaR, mgcv, nnet, parsnip, quantreg, randomForest, ranger, rpart,
78+
rstanarm, tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
7679
xgboost
7780
Config/rcmdcheck/ignore-inconsequential-notes: true
7881
Config/testthat/edition: 3
7982
Encoding: UTF-8
8083
LazyData: true
8184
Roxygen: list(markdown = TRUE)
82-
Remotes:
83-
r-lib/sparsevctrs
8485
RoxygenNote: 7.3.2

NAMESPACE

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ export(.dat)
185185
export(.extract_surv_status)
186186
export(.extract_surv_time)
187187
export(.facts)
188+
export(.get_prediction_column_names)
188189
export(.lvls)
189190
export(.model_param_name_key)
190191
export(.obs)
@@ -202,7 +203,6 @@ export(bag_mars)
202203
export(bag_mlp)
203204
export(bag_tree)
204205
export(bart)
205-
export(bartMachine_interval_calc)
206206
export(boost_tree)
207207
export(case_weights_allowed)
208208
export(cforest_train)
@@ -264,6 +264,7 @@ export(make_classes)
264264
export(make_engine_list)
265265
export(make_seealso_list)
266266
export(mars)
267+
export(matrix_to_quantile_pred)
267268
export(max_mtry_formula)
268269
export(maybe_data_frame)
269270
export(maybe_matrix)
@@ -311,7 +312,6 @@ export(rand_forest)
311312
export(repair_call)
312313
export(req_pkgs)
313314
export(required_pkgs)
314-
export(rpart_train)
315315
export(rule_fit)
316316
export(set_args)
317317
export(set_dependency)
@@ -377,6 +377,7 @@ importFrom(generics,tune_args)
377377
importFrom(generics,varying_args)
378378
importFrom(ggplot2,autoplot)
379379
importFrom(glue,glue_collapse)
380+
importFrom(hardhat,contr_one_hot)
380381
importFrom(hardhat,extract_fit_engine)
381382
importFrom(hardhat,extract_fit_time)
382383
importFrom(hardhat,extract_parameter_dials)

NEWS.md

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,51 @@
11
# parsnip (development version)
22

3+
## New Features
4+
5+
* A new model mode (`"quantile regression"`) was added. Including:
6+
* A `linear_reg()` engine for `"quantreg"`.
7+
* Predictions are encoded via a custom vector type. See [hardhat::quantile_pred()].
8+
* Predicted quantile levels are designated when the new mode is specified. See `?set_mode`.
9+
310
* `fit_xy()` can now take dgCMatrix input for `x` argument (#1121).
411

512
* `fit_xy()` can now take sparse tibbles as data values (#1165).
613

714
* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).
815

9-
* Transitioned package errors and warnings to use cli (#1147 and #1148 by
10-
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
11-
#1161, #1081).
16+
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
17+
18+
## Other Changes
19+
20+
* Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081).
1221

1322
* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).
1423

1524
* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).
1625

17-
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
26+
* If linear regression is requested with a Poisson family, an error will occur and refer the user to `poisson_reg()` (#1219).
27+
28+
* The deprecated function `rpart_train()` was removed after its deprecation period (#1044).
29+
30+
## Bug Fixes
31+
32+
* Make sure that parsnip does not convert ordered factor predictions to be unordered.
1833

1934
* Ensure that `knit_engine_docs()` has the required packages installed (#1156).
2035

2136
* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
2237

38+
* Fixed bug related to using local (non-package) models (#1229)
39+
40+
* `tunable()` now references a dials object for the `mixture` parameter (#1236)
41+
42+
## Breaking Change
43+
44+
* For quantile prediction, the `quantile` argument to `predict()` has been deprecate in facor of `quantile_levels`. This does not affect models with mode `"quantile regression"`.
45+
46+
* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model.
47+
48+
2349
# parsnip 1.2.1
2450

2551
* Added a missing `tidy()` method for survival analysis glmnet models (#1086).

R/aaa-import-standalone-types-check.R

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
# Standalone file: do not edit by hand
2-
# Source: <https://github.com/r-lib/rlang/blob/main/R/standalone-types-check.R>
3-
# ----------------------------------------------------------------------
4-
#
51
# ---
62
# repo: r-lib/rlang
73
# file: standalone-types-check.R
@@ -13,6 +9,9 @@
139
#
1410
# ## Changelog
1511
#
12+
# 2024-08-15:
13+
# - `check_character()` gains an `allow_na` argument (@martaalcalde, #1724)
14+
#
1615
# 2023-03-13:
1716
# - Improved error messages of number checkers (@teunbrand)
1817
# - Added `allow_infinite` argument to `check_number_whole()` (@mgirlich).
@@ -461,15 +460,28 @@ check_formula <- function(x,
461460

462461
# Vectors -----------------------------------------------------------------
463462

463+
# TODO: Figure out what to do with logical `NA` and `allow_na = TRUE`
464+
464465
check_character <- function(x,
465466
...,
467+
allow_na = TRUE,
466468
allow_null = FALSE,
467469
arg = caller_arg(x),
468470
call = caller_env()) {
471+
469472
if (!missing(x)) {
470473
if (is_character(x)) {
474+
if (!allow_na && any(is.na(x))) {
475+
abort(
476+
sprintf("`%s` can't contain NA values.", arg),
477+
arg = arg,
478+
call = call
479+
)
480+
}
481+
471482
return(invisible(NULL))
472483
}
484+
473485
if (allow_null && is_null(x)) {
474486
return(invisible(NULL))
475487
}
@@ -479,7 +491,6 @@ check_character <- function(x,
479491
x,
480492
"a character vector",
481493
...,
482-
allow_na = FALSE,
483494
allow_null = allow_null,
484495
arg = arg,
485496
call = call

R/aaa_models.R

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Initialize model environments
22

3-
all_modes <- c("classification", "regression", "censored regression")
3+
all_modes <- c("classification", "regression", "censored regression", "quantile regression")
44

55
# ------------------------------------------------------------------------------
66

@@ -195,8 +195,8 @@ stop_missing_engine <- function(cls, call) {
195195
}
196196

197197
check_mode_for_new_engine <- function(cls, eng, mode, call = caller_env()) {
198-
all_modes <- get_from_env(paste0(cls, "_modes"))
199-
if (!(mode %in% all_modes)) {
198+
model_modes <- get_from_env(paste0(cls, "_modes"))
199+
if (!(mode %in% model_modes)) {
200200
cli::cli_abort(
201201
"{.val {mode}} is not a known mode for model {.fn {cls}}.",
202202
call = call
@@ -796,7 +796,7 @@ is_discordant_info <- function(model, mode, eng, candidate,
796796
if (component == "predict" & !is.null(pred_type)) {
797797

798798
current <- dplyr::filter(current, type == pred_type)
799-
p_type <- paste0("and prediction type '", pred_type, "'")
799+
p_type <- "and prediction type {.val {pred_type}} "
800800
} else {
801801
p_type <- ""
802802
}
@@ -809,9 +809,12 @@ is_discordant_info <- function(model, mode, eng, candidate,
809809

810810
if (!same_info) {
811811
cli::cli_abort(
812-
"The combination of engine {.var {eng}} and mode {.var {mode}} \\
813-
{.val {p_type}} already has {component} data for model {.var {model}} \\
814-
and the new information being registered is different.",
812+
paste0(
813+
"The combination of engine {.var {eng}} and mode {.var {mode}} ",
814+
p_type,
815+
"already has {component} data for model {.var {model}}
816+
and the new information being registered is different."
817+
),
815818
call = call
816819
)
817820
}

R/aaa_quantiles.R

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#' Reformat quantile predictions
2+
#'
3+
#' @param x A matrix of predictions with rows as samples and columns as quantile
4+
#' levels.
5+
#' @param object A parsnip `model_fit` object from a quantile regression model.
6+
#' @keywords internal
7+
#' @export
8+
matrix_to_quantile_pred <- function(x, object) {
9+
if (!is.matrix(x)) {
10+
x <- as.matrix(x)
11+
}
12+
rownames(x) <- NULL
13+
n_pred_quantiles <- ncol(x)
14+
quantile_levels <- object$spec$quantile_levels
15+
16+
tibble::new_tibble(x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels)))
17+
}

R/arguments.R

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ check_eng_args <- function(args, obj, core_args) {
2121
common_args <- intersect(protected_args, names(args))
2222
if (length(common_args) > 0) {
2323
args <- args[!(names(args) %in% common_args)]
24-
common_args <- paste0(common_args, collapse = ", ")
2524
cli::cli_warn(
26-
"The argument{?s} {.arg {common_args}} cannot be manually
27-
modified and {?was/were} removed."
25+
c(
26+
"The argument{?s} {.arg {common_args}} cannot be manually modified
27+
and {?was/were} removed."
28+
),
29+
class = "parsnip_protected_arg_warning"
2830
)
2931
}
3032
args
@@ -49,6 +51,8 @@ check_eng_args <- function(args, obj, core_args) {
4951
#' set_args(mtry = 3, importance = TRUE) %>%
5052
#' set_mode("regression")
5153
#'
54+
#' linear_reg() %>%
55+
#' set_mode("quantile regression", quantile_levels = c(0.2, 0.5, 0.8))
5256
#' @export
5357
set_args <- function(object, ...) {
5458
UseMethod("set_args")
@@ -89,12 +93,18 @@ set_args.default <- function(object,...) {
8993

9094
#' @rdname set_args
9195
#' @export
92-
set_mode <- function(object, mode) {
96+
set_mode <- function(object, mode, ...) {
9397
UseMethod("set_mode")
9498
}
9599

100+
#' @rdname set_args
101+
#' @param quantile_levels A vector of values between zero and one (only for the
102+
#' `"quantile regression"` mode); otherwise, it is `NULL`. The model uses these
103+
#' values to appropriately train quantile regression models to make predictions
104+
#' for these values (e.g., `quantile_levels = 0.5` is the median).
96105
#' @export
97-
set_mode.model_spec <- function(object, mode) {
106+
set_mode.model_spec <- function(object, mode, quantile_levels = NULL, ...) {
107+
check_dots_empty()
98108
cls <- class(object)[1]
99109
if (rlang::is_missing(mode)) {
100110
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
@@ -111,11 +121,21 @@ set_mode.model_spec <- function(object, mode) {
111121

112122
object$mode <- mode
113123
object$user_specified_mode <- TRUE
124+
if (mode == "quantile regression") {
125+
hardhat::check_quantile_levels(quantile_levels)
126+
} else {
127+
if (!is.null(quantile_levels)) {
128+
cli::cli_warn("{.arg quantile_levels} is only used when the mode is
129+
{.val quantile regression}.")
130+
}
131+
}
132+
133+
object$quantile_levels <- quantile_levels
114134
object
115135
}
116136

117137
#' @export
118-
set_mode.default <- function(object, mode) {
138+
set_mode.default <- function(object, mode, ...) {
119139
error_set_object(object, func = "set_mode")
120140

121141
invisible(FALSE)
@@ -240,7 +260,7 @@ make_form_call <- function(object, env = NULL) {
240260
}
241261

242262
# TODO we need something to indicate that case weights are being used.
243-
make_xy_call <- function(object, target, env) {
263+
make_xy_call <- function(object, target, env, call = rlang::caller_env()) {
244264
fit_args <- object$method$fit$args
245265
uses_weights <- has_weights(env)
246266

@@ -265,7 +285,7 @@ make_xy_call <- function(object, target, env) {
265285
data.frame = rlang::expr(maybe_data_frame(x)),
266286
matrix = rlang::expr(maybe_matrix(x)),
267287
dgCMatrix = rlang::expr(maybe_sparse_matrix(x)),
268-
cli::cli_abort("Invalid data type target: {target}.")
288+
cli::cli_abort("Invalid data type target: {target}.", call = call)
269289
)
270290
if (uses_weights) {
271291
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)

R/autoplot.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
4141
}
4242

4343

44-
map_glmnet_coefs <- function(x) {
44+
map_glmnet_coefs <- function(x, call = rlang::caller_env()) {
4545
coefs <- coef(x)
4646
# If parsnip is used to fit the model, glmnet should be attached and this will
4747
# work. If an object is loaded from a new session, they will need to load the
4848
# package.
4949
if (is.null(coefs)) {
5050
cli::cli_abort(
51-
"Please load the {.pkg glmnet} package before running {.fun autoplot}."
51+
"Please load the {.pkg glmnet} package before running {.fun autoplot}.",
52+
call = call
5253
)
5354
}
5455
p <- x$dim[1]
@@ -89,9 +90,10 @@ top_coefs <- function(x, top_n = 5) {
8990
dplyr::slice(seq_len(top_n))
9091
}
9192

92-
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
93+
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L,
94+
call = rlang::caller_env(), ...) {
9395
tidy_coefs <-
94-
map_glmnet_coefs(x) %>%
96+
map_glmnet_coefs(x, call = call) %>%
9597
dplyr::filter(penalty >= min_penalty)
9698

9799
actual_min_penalty <- min(tidy_coefs$penalty)

0 commit comments

Comments
 (0)