Skip to content

Commit 57e8c10

Browse files
authored
Merge pull request #199 from tidymodels/rlang-tibble-updates
rlang and tibble updates
2 parents 6d0a5e7 + 373cefd commit 57e8c10

23 files changed

+159
-46
lines changed

.travis.yml

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,36 @@ matrix:
2222
allow_failures:
2323
- r: 3.3 # inum install failure (seg fault)
2424
- r: 3.2 # partykit install failure (libcoin needs >= 3.4.0)
25-
- r: 3.4 # mvtnorm requires >= 3.5.0
25+
- r: 3.4 # mvtnorm requires >= 3.5.0
26+
27+
r_binary_packages:
28+
- RCurl
29+
- dplyr
30+
- glue
31+
- magrittr
32+
- stringi
33+
- stringr
34+
- munsell
35+
- rlang
36+
- reshape2
37+
- scales
38+
- tibble
39+
- ggplot2
40+
- Rcpp
41+
- RcppEigen
42+
- BH
43+
- glmnet
44+
- earth
45+
- sparklyr
46+
- flexsurv
47+
- ranger
48+
- randomforest
49+
- xgboost
50+
- C50
51+
2652

2753
cache:
54+
packages: true
2855
directories:
2956
- $HOME/.keras
3057
- $HOME/.cache/pip

DESCRIPTION

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Imports:
2828
stats,
2929
tidyr,
3030
globals,
31-
vctrs
31+
vctrs (>= 0.2.0)
3232
Roxygen: list(markdown = TRUE)
3333
RoxygenNote: 6.1.1
3434
Suggests:
@@ -50,4 +50,3 @@ Suggests:
5050
rpart,
5151
MASS,
5252
nlme
53-
Remotes: r-lib/vctrs

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Unplanned release based on CRAN requirements for Solaris.
2222

2323
* A suite of internal functions were added to help with upcoming model tuning features.
2424

25+
* A `parsnip` object always saved the name(s) of the outcome variable(s) for proper naming of the predicted values.
26+
2527

2628
# parsnip 0.0.2
2729

R/aaa.R

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11

22
maybe_multivariate <- function(results, object) {
3-
if (isTRUE(ncol(results) > 1))
4-
results <- as_tibble(results)
5-
else
3+
4+
if (isTRUE(ncol(results) > 1)) {
5+
nms <- colnames(results)
6+
results <- as_tibble(results, .name_repair = "minimal")
7+
if (length(nms) == 0 && length(object$preproc$y_var) == ncol(results)) {
8+
names(results) <- object$preproc$y_var
9+
}
10+
} else {
611
results <- unname(results[, 1])
12+
}
713
results
814
}
915

R/arguments.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#' @importFrom rlang expr enquos enquo quos is_quosure call2 quo_get_expr ll
22
#' @importFrom rlang abort current_env get_expr is_missing is_null is_symbolic missing_arg
33
null_value <- function(x) {
4-
res <- if(is_quosure(x))
5-
isTRUE(all.equal(x[[-1]], quote(NULL))) else
6-
isTRUE(all.equal(x, NULL))
4+
if (is_quosure(x)) {
5+
res <- isTRUE(all.equal(rlang::get_expr(x), expr(NULL)))
6+
} else {
7+
res <- isTRUE(all.equal(x, NULL))
8+
}
79
res
810
}
911

R/fit_helpers.R

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ form_form <-
5959
env = env,
6060
...
6161
)
62-
res$preproc <- NA
62+
res$preproc <- list(y_var = all.vars(env$formula[[2]]))
6363
res
6464
}
6565

@@ -114,7 +114,12 @@ xy_xy <- function(object, env, control, target = "none", ...) {
114114
env = env,
115115
...
116116
)
117-
res$preproc <- NA
117+
if (is.vector(env$y)) {
118+
y_name <- character(0)
119+
} else {
120+
y_name <- colnames(env$y)
121+
}
122+
res$preproc <- list(y_var = y_name)
118123
res
119124
}
120125

@@ -144,6 +149,7 @@ form_xy <- function(object, control, env,
144149
control = control,
145150
target = target
146151
)
152+
data_obj$y_var <- all.vars(env$formula[[2]])
147153
data_obj$x <- NULL
148154
data_obj$y <- NULL
149155
data_obj$weights <- NULL
@@ -177,7 +183,12 @@ xy_form <- function(object, env, control, ...) {
177183
control = control,
178184
...
179185
)
180-
res$preproc <- data_obj["x_var"]
186+
if (is.vector(env$y)) {
187+
data_obj$y_var <- character(0)
188+
} else {
189+
data_obj$y_var <- colnames(env$y)
190+
}
191+
res$preproc <- data_obj[c("x_var", "y_var")]
181192
res
182193
}
183194

R/logistic_reg_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ set_pred(
333333
value = list(
334334
pre = NULL,
335335
post = function(x, object) {
336-
x <- as_tibble(x)
337336
colnames(x) <- object$lvl
337+
x <- as_tibble(x)
338338
x
339339
},
340340
func = c(pkg = "keras", fun = "predict_proba"),

R/mars.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
#' @examples
6060
#' mars(mode = "regression", num_terms = 5)
6161
#' @export
62-
6362
mars <-
6463
function(mode = "unknown",
6564
num_terms = NULL, prod_degree = NULL, prune_method = NULL) {
@@ -149,7 +148,7 @@ translate.mars <- function(x, engine = x$engine, ...) {
149148
# see if it is there and, if not, add the default value.
150149
if (x$mode == "classification") {
151150
if (!("glm" %in% names(x$eng_args))) {
152-
x$eng_args$glm <- quote(list(family = stats::binomial))
151+
x$eng_args$glm <- rlang::quo(list(family = stats::binomial))
153152
}
154153
}
155154

@@ -193,8 +192,8 @@ earth_reg_updater <- function(num, object, new_data, ...) {
193192
if (ncol(pred) == 1) {
194193
res <- tibble::tibble(.pred = pred[, 1], nprune = num)
195194
} else {
196-
res <- tibble::as_tibble(res)
197195
names(res) <- paste0(".pred_", names(res))
196+
res <- tibble::as_tibble(res)
198197
res$nprune <- num
199198
}
200199
res

R/misc.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,15 @@ check_outcome <- function(y, spec) {
223223
invisible(NULL)
224224
}
225225

226+
227+
# Get's a character string of varible names used as the outcome
228+
# in a terms object
229+
terms_y <- function(x) {
230+
att <- attributes(x)
231+
resp_ind <- att$response
232+
y_expr <- att$predvars[[resp_ind + 1]]
233+
all.vars(y_expr)
234+
}
235+
236+
237+

R/mlp.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ keras_mlp <-
370370
fit_call <- rlang::call_modify(fit_call, !!!arg_values$fit)
371371

372372
history <- eval_tidy(fit_call)
373+
model$y_names <- colnames(y)
373374
model
374375
}
375376

@@ -379,8 +380,9 @@ nnet_softmax <- function(results, object) {
379380
results <- cbind(1 - results, results)
380381

381382
results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
382-
results <- as_tibble(t(results))
383+
results <- t(results)
383384
names(results) <- paste0(".pred_", object$lvl)
385+
results <- as_tibble(results)
384386
results
385387
}
386388

@@ -419,7 +421,7 @@ parse_keras_args <- function(...) {
419421
}
420422

421423
mlp_num_weights <- function(p, hidden_units, classes) {
422-
((p+1) * hidden_units) + ((hidden_units+1) * classes)
424+
((p + 1) * hidden_units) + ((hidden_units+1) * classes)
423425
}
424426

425427

0 commit comments

Comments
 (0)