Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ export(bag_mars)
export(bag_mlp)
export(bag_tree)
export(bart)
export(bartMachine_interval_calc)
export(boost_tree)
export(case_weights_allowed)
export(cforest_train)
Expand Down
4 changes: 2 additions & 2 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ make_form_call <- function(object, env = NULL) {
}

# TODO we need something to indicate that case weights are being used.
make_xy_call <- function(object, target, env) {
make_xy_call <- function(object, target, env, call = rlang::caller_env()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll just make this comment once to apply throughout the PR, but rlang is imported with @import rlang so we don't need to namespace rlang for these!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It all comes back to psock... we started namespacing everything in these packages because importing was not enough.

fit_args <- object$method$fit$args
uses_weights <- has_weights(env)

Expand All @@ -283,7 +283,7 @@ make_xy_call <- function(object, target, env) {
data.frame = rlang::expr(maybe_data_frame(x)),
matrix = rlang::expr(maybe_matrix(x)),
dgCMatrix = rlang::expr(maybe_sparse_matrix(x)),
cli::cli_abort("Invalid data type target: {target}.")
cli::cli_abort("Invalid data type target: {target}.", call = call)
)
if (uses_weights) {
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)
Expand Down
10 changes: 6 additions & 4 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
}


map_glmnet_coefs <- function(x) {
map_glmnet_coefs <- function(x, call = rlang::caller_env()) {
coefs <- coef(x)
# If parsnip is used to fit the model, glmnet should be attached and this will
# work. If an object is loaded from a new session, they will need to load the
# package.
if (is.null(coefs)) {
cli::cli_abort(
"Please load the {.pkg glmnet} package before running {.fun autoplot}."
"Please load the {.pkg glmnet} package before running {.fun autoplot}.",
call = call
)
}
p <- x$dim[1]
Expand Down Expand Up @@ -89,9 +90,10 @@ top_coefs <- function(x, top_n = 5) {
dplyr::slice(seq_len(top_n))
}

autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L,
call = rlang::caller_env(), ...) {
tidy_coefs <-
map_glmnet_coefs(x) %>%
map_glmnet_coefs(x, call = call) %>%
dplyr::filter(penalty >= min_penalty)

actual_min_penalty <- min(tidy_coefs$penalty)
Expand Down
48 changes: 0 additions & 48 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,61 +130,13 @@ update.bart <-
)
}


#' Developer functions for predictions via BART models
#' @export
#' @keywords internal
#' @name bart-internal
#' @inheritParams predict.model_fit
#' @param obj A parsnip object.
#' @param ci Confidence (TRUE) or prediction interval (FALSE)
#' @param level Confidence level.
#' @param std_err Attach column for standard error of prediction or not.
bartMachine_interval_calc <- function(new_data, obj, ci = TRUE, level = 0.95) {
  # Computes interval bounds for a fitted model stored in `obj$fit` by calling
  # into the bartMachine package. Returns a tibble with `.pred_lower` and
  # `.pred_upper`; for prediction intervals (`ci = FALSE`) a `.std_err` column
  # is appended when the spec's `std_error` extra is set.
  if (obj$spec$mode == "classification") {
    cli::cli_abort(
      "Prediction intervals are not possible for classification"
    )
  }

  # The flag lives deep inside the spec and may be absent (NULL). Coerce it to
  # a single TRUE/FALSE so the scalar conditions below can never see a
  # zero-length value.
  get_std_err <- isTRUE(obj$spec$method$pred$pred_int$extras$std_error)

  # Credible and prediction intervals differ only in the bartMachine function
  # that is called and in the name of its confidence-level argument.
  if (ci) {
    cl <-
      rlang::call2(
        "calc_credible_intervals",
        .ns = "bartMachine",
        bart_machine = rlang::expr(obj$fit),
        new_data = rlang::expr(new_data),
        ci_conf = level
      )
  } else {
    cl <-
      rlang::call2(
        "calc_prediction_intervals",
        .ns = "bartMachine",
        bart_machine = rlang::expr(obj$fit),
        new_data = rlang::expr(new_data),
        pi_conf = level
      )
  }
  res <- rlang::eval_tidy(cl)

  # Only the prediction-interval result carries the per-row samples, so the
  # standard error can only be derived on that path.
  add_std_err <- !ci && get_std_err
  if (!ci) {
    if (add_std_err) {
      .std_error <- apply(res$all_prediction_samples, 1, stats::sd, na.rm = TRUE)
    }
    res <- res$interval
  }

  res <- tibble::as_tibble(res)
  names(res) <- c(".pred_lower", ".pred_upper")
  if (add_std_err) {
    res$.std_err <- .std_error
  }
  res
}

#' @export
#' @rdname bart-internal
#' @keywords internal
dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALSE) {
types <- c("numeric", "class", "prob", "conf_int", "pred_int")
Expand Down
13 changes: 9 additions & 4 deletions R/condense_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#'
#' @return A control object with the same elements and classes of `ref`, with
#' values of `x`.
#' @param call The execution environment of a currently running function, e.g.
#' `caller_env()`. The function will be mentioned in error messages as the
#' source of the error. See the call argument of [rlang::abort()] for more
#' information.
#' @keywords internal
#' @export
#'
Expand All @@ -20,16 +24,17 @@
#'
#' ctrl <- condense_control(ctrl, control_parsnip())
#' str(ctrl)
condense_control <- function(x, ref) {
condense_control <- function(x, ref, ..., call = rlang::caller_env()) {
check_dots_empty()
mismatch <- setdiff(names(ref), names(x))
if (length(mismatch)) {
cli::cli_abort(
c(
"Object of class {.cls class(x)[1]} cannot be coerced to
object of class {.cls class(ref)[1]}.",
"{.obj_type_friendly {x}} cannot be coerced to {.obj_type_friendly {ref}}.",
"i" = "{cli::qty(mismatch)} The argument{?s} {.arg {mismatch}}
{?is/are} missing."
)
),
call = call
)
}
res <- x[names(ref)]
Expand Down
9 changes: 7 additions & 2 deletions R/contr_one_hot.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#' This contrast function produces a model matrix with indicator columns for
#' each level of each factor.
#'
#' @param n A vector of character factor levels or the number of unique levels.
#' @param n A vector of character factor levels (of length >=1) or the number
#' of unique levels (>= 1).
#' @param contrasts This argument is for backwards compatibility and only the
#' default of `TRUE` is supported.
#' @param sparse This argument is for backwards compatibility and only the
Expand All @@ -24,9 +25,13 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
}

if (is.character(n)) {
if (length(n) < 1) {
cli::cli_abort("{.arg n} cannot be empty.")
}
names <- n
n <- length(names)
} else if (is.numeric(n)) {
check_number_whole(n, min = 1)
n <- as.integer(n)

if (length(n) != 1L) {
Expand All @@ -35,7 +40,7 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {

names <- as.character(seq_len(n))
} else {
cli::cli_abort("{.arg n} must be a character vector or an integer of size 1.")
check_number_whole(n, min = 1)
}

out <- diag(n)
Expand Down
28 changes: 17 additions & 11 deletions R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@
na.action = na.omit,
indicators = "traditional",
composition = "data.frame",
remove_intercept = TRUE) {
remove_intercept = TRUE,
call = rlang::caller_env()) {
if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) {
cli::cli_abort(
"{.arg composition} should be either {.val data.frame}, {.val matrix}, or
{.val dgCMatrix}."
{.val dgCMatrix}.",
call = call
)
}

if (sparsevctrs::has_sparse_elements(data)) {
cli::cli_abort(
"Sparse data cannot be used with formula interface. Please use
{.fn fit_xy} instead."
"Sparse data cannot be used with formula interface. Please use
{.fn fit_xy} instead.",
call = call
)
}

Expand Down Expand Up @@ -84,7 +87,7 @@

w <- as.vector(model.weights(mod_frame))
if (!is.null(w) && !is.numeric(w)) {
cli::cli_abort("{.arg weights} must be a numeric vector.")
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
}

# TODO: Do we actually use the offset when fitting?
Expand Down Expand Up @@ -175,10 +178,12 @@
.convert_form_to_xy_new <- function(object,
new_data,
na.action = na.pass,
composition = "data.frame") {
composition = "data.frame",
call = rlang::caller_env()) {
if (!(composition %in% c("data.frame", "matrix"))) {
cli::cli_abort(
"{.arg composition} should be either {.val data.frame} or {.val matrix}."
"{.arg composition} should be either {.val data.frame} or {.val matrix}.",
call = call
)
}

Expand Down Expand Up @@ -244,9 +249,10 @@
y,
weights = NULL,
y_name = "..y",
remove_intercept = TRUE) {
remove_intercept = TRUE,
call = rlang::caller_env()) {
if (is.vector(x)) {
cli::cli_abort("{.arg x} cannot be a vector.")
cli::cli_abort("{.arg x} cannot be a vector.", call = call)
}

if (remove_intercept) {
Expand Down Expand Up @@ -279,10 +285,10 @@

if (!is.null(weights)) {
if (!is.numeric(weights)) {
cli::cli_abort("{.arg weights} must be a numeric vector.")
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
}
if (length(weights) != nrow(x)) {
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.")
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.", call = call)
}

form <- patch_formula_environment_with_case_weights(
Expand Down
16 changes: 9 additions & 7 deletions R/descriptors.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,23 @@ NULL

# Descriptor retrievers --------------------------------------------------------

get_descr_form <- function(formula, data) {
get_descr_form <- function(formula, data, call = rlang::caller_env()) {
if (inherits(data, "tbl_spark")) {
res <- get_descr_spark(formula, data)
} else {
res <- get_descr_df(formula, data)
res <- get_descr_df(formula, data, call = call)
}
res
}

get_descr_df <- function(formula, data) {
get_descr_df <- function(formula, data, call = rlang::caller_env()) {

tmp_dat <-
.convert_form_to_xy_fit(formula,
data,
indicators = "none",
remove_intercept = TRUE)
remove_intercept = TRUE,
call = call)

if(is.factor(tmp_dat$y)) {
.lvls <- function() {
Expand All @@ -136,7 +137,8 @@ get_descr_df <- function(formula, data) {
formula,
data,
indicators = "traditional",
remove_intercept = TRUE
remove_intercept = TRUE,
call = call
)$x
)
}
Expand Down Expand Up @@ -263,7 +265,7 @@ get_descr_spark <- function(formula, data) {
)
}

get_descr_xy <- function(x, y) {
get_descr_xy <- function(x, y, call = rlang::caller_env()) {

.lvls <- if (is.factor(y)) {
function() table(y, dnn = NULL)
Expand Down Expand Up @@ -291,7 +293,7 @@ get_descr_xy <- function(x, y) {
}

.dat <- function() {
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE)$data
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE, call = call)$data
}

.x <- function() {
Expand Down
5 changes: 3 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ fit.model_spec <-
}

if (all(c("x", "y") %in% names(dots))) {
cli::cli_abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.")
cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.")
}
cl <- match.call(expand.dots = TRUE)
# Create an environment with the evaluated argument objects. This will be
Expand Down Expand Up @@ -307,7 +307,8 @@ fit_xy.model_spec <-

if (object$engine == "spark") {
cli::cli_abort(
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
"spark objects can only be used with the formula interface to {.fn fit}
with a spark data object."
)
}

Expand Down
9 changes: 5 additions & 4 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ form_form <-

# if descriptors are needed, update descr_env with the calculated values
if (requires_descrs(object)) {
data_stats <- get_descr_form(env$formula, env$data)
data_stats <- get_descr_form(env$formula, env$data, call = call)
scoped_descrs(data_stats)
}

Expand Down Expand Up @@ -86,7 +86,7 @@ xy_xy <- function(object,

# if descriptors are needed, update descr_env with the calculated values
if (requires_descrs(object)) {
data_stats <- get_descr_xy(env$x, env$y)
data_stats <- get_descr_xy(env$x, env$y, call = call)
scoped_descrs(data_stats)
}

Expand All @@ -96,7 +96,7 @@ xy_xy <- function(object,
# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)

fit_call <- make_xy_call(object, target, env)
fit_call <- make_xy_call(object, target, env, call)

res <- list(lvl = levels(env$y), spec = object)

Expand Down Expand Up @@ -141,7 +141,8 @@ form_xy <- function(object, control, env,
...,
composition = target,
indicators = indicators,
remove_intercept = remove_intercept
remove_intercept = remove_intercept,
call = call
)
env$x <- data_obj$x
env$y <- data_obj$y
Expand Down
Loading
Loading