Skip to content

Commit 3424cf1

Browse files
authored
Redo unit tests (#75)
* expand unit tests for class weights * different testing strategy * different testing strategy * remove regression test files * convert rlang calls to cli
1 parent 88d6002 commit 3424cf1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1118
-740
lines changed

R/checks.R

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ check_missing_data <- function(x, y, fn = "some function", verbose = FALSE) {
77
cl_chr <- as.character()
88
msg <- paste0(fn, "() removed ", sum(!compl_data), " rows of ",
99
"data due to missing values.")
10-
rlang::warn(msg)
10+
cli::cli_warn(msg)
1111
}
1212
}
1313
list(x = x, y = y)
@@ -18,14 +18,14 @@ check_data_att <- function(x, y) {
1818

1919
# check matrices/vectors, matrix type, matrix column names
2020
if (!is.matrix(x) || !is.numeric(x)) {
21-
rlang::abort("'x' should be a numeric matrix.")
21+
cli::cli_abort("'x' should be a numeric matrix.")
2222
}
2323
nms <- colnames(x)
2424
if (length(nms) != ncol(x)) {
25-
rlang::abort("Every column of 'x' should have a name.")
25+
cli::cli_abort("Every column of 'x' should have a name.")
2626
}
2727
if (!is.vector(y) & !is.factor(y)) {
28-
rlang::abort("'y' should be a vector.")
28+
cli::cli_abort("'y' should be a vector.")
2929
}
3030
invisible(NULL)
3131
}
@@ -87,12 +87,12 @@ check_integer <-
8787

8888
if (!is.integer(x)) {
8989
msg <- paste(format_msg(fn, arg), "to be integer.")
90-
rlang::abort(msg)
90+
cli::cli_abort(msg)
9191
}
9292

9393
if (single && length(x) > 1) {
9494
msg <- paste(format_msg(fn, arg), "to be a single integer.")
95-
rlang::abort(msg)
95+
cli::cli_abort(msg)
9696
}
9797

9898
out_of_range <- check_rng(x, x_min, x_max, incl)
@@ -101,7 +101,7 @@ check_integer <-
101101
" to be an integer on ",
102102
ifelse(incl[[1]], "[", "("), x_min, ", ",
103103
x_max, ifelse(incl[[2]], "]", ")"), ".")
104-
rlang::abort(msg)
104+
cli::cli_abort(msg)
105105
}
106106

107107
invisible(TRUE)
@@ -116,12 +116,12 @@ check_double <- function(x,
116116

117117
if (!is.double(x)) {
118118
msg <- paste(format_msg(fn, arg), "to be a double.")
119-
rlang::abort(msg)
119+
cli::cli_abort(msg)
120120
}
121121

122122
if (single && length(x) > 1) {
123123
msg <- paste(format_msg(fn, arg), "to be a single double.")
124-
rlang::abort(msg)
124+
cli::cli_abort(msg)
125125
}
126126

127127
out_of_range <- check_rng(x, x_min, x_max, incl)
@@ -130,7 +130,7 @@ check_double <- function(x,
130130
" to be a double on ",
131131
ifelse(incl[[1]], "[", "("), x_min, ", ",
132132
x_max, ifelse(incl[[2]], "]", ")"), ".")
133-
rlang::abort(msg)
133+
cli::cli_abort(msg)
134134
}
135135

136136
invisible(TRUE)
@@ -142,18 +142,18 @@ check_character <- function(x, single = TRUE, vals = NULL, fn = NULL) {
142142

143143
if (!is.character(x)) {
144144
msg <- paste(format_msg(fn, arg), "to be character.")
145-
rlang::abort(msg)
145+
cli::cli_abort(msg)
146146
}
147147

148148
if (single && length(x) > 1) {
149149
msg <- paste(format_msg(fn, arg), "to be a single character string.")
150-
rlang::abort(msg)
150+
cli::cli_abort(msg)
151151
}
152152

153153
if (!is.null(vals)) {
154154
if (any(!(x %in% vals))) {
155155
msg <- paste0(format_msg(fn, arg), " contains an incorrect value.")
156-
rlang::abort(msg)
156+
cli::cli_abort(msg)
157157
}
158158
}
159159

@@ -166,12 +166,12 @@ check_logical <- function(x, single = TRUE, fn = NULL) {
166166

167167
if (!is.logical(x)) {
168168
msg <- paste(format_msg(fn, arg), "to be logical.")
169-
rlang::abort(msg)
169+
cli::cli_abort(msg)
170170
}
171171

172172
if (single && length(x) > 1) {
173173
msg <- paste(format_msg(fn, arg), "to be a single logical.")
174-
rlang::abort(msg)
174+
cli::cli_abort(msg)
175175
}
176176
invisible(TRUE)
177177
}
@@ -188,7 +188,7 @@ check_class_weights <- function(wts, lvls, xtab, fn) {
188188
}
189189
if (!is.numeric(wts)) {
190190
msg <- paste(format_msg(fn, "class_weights"), "to a numeric vector")
191-
rlang::abort(msg)
191+
cli::cli_abort(msg)
192192
}
193193

194194
if (length(wts) == 1) {
@@ -202,6 +202,7 @@ check_class_weights <- function(wts, lvls, xtab, fn) {
202202
if (length(lvls) != length(wts)) {
203203
msg <- paste0("There were ", length(wts), " class weights given but ",
204204
length(lvls), " were expected.")
205+
cli::cli_abort(msg)
205206
}
206207

207208
nms <- names(wts)
@@ -211,7 +212,7 @@ check_class_weights <- function(wts, lvls, xtab, fn) {
211212
if (!identical(sort(nms), sort(lvls))) {
212213
msg <- paste("Names for class weights should be:",
213214
paste0("'", lvls, "'", collapse = ", "))
214-
rlang::abort(msg)
215+
cli::cli_abort(msg)
215216
}
216217
wts <- wts[lvls]
217218
}

R/coef.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
brulee_coefs <- function(object, epoch = NULL, ...) {
22
if (!is.null(epoch) && length(epoch) != 1) {
3-
rlang::abort("'epoch' should be a single integer.")
3+
cli::cli_abort("'epoch' should be a single integer.")
44
}
55
max_epochs <- length(object$estimates)
66

@@ -9,7 +9,7 @@ brulee_coefs <- function(object, epoch = NULL, ...) {
99
} else {
1010
if (epoch > max_epochs) {
1111
msg <- glue::glue("There were only {max_epochs} epochs fit. Setting 'epochs' to {max_epochs}.")
12-
rlang::warn(msg)
12+
cli::cli_warn(msg)
1313
epoch <- max_epochs
1414
}
1515

R/convert_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ matrix_to_dataset <- function(x, y) {
3030
scale_stats <- function(x) {
3131
res <- list(mean = mean(x, na.rm = TRUE), sd = stats::sd(x, na.rm = TRUE))
3232
if (res$sd == 0) {
33-
rlang::abort("There is no variation in `y`.")
33+
cli::cli_abort("There is no variation in `y`.")
3434
}
3535
res
3636
}

R/linear_reg-fit.R

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ brulee_linear_reg_bridge <- function(processed, epochs, optimizer,
295295
learn_rate, momentum, penalty, mixture, dropout,
296296
validation, batch_size, stop_iter, verbose, ...) {
297297
if(!torch::torch_is_installed()) {
298-
rlang::abort("The torch backend has not been installed; use `torch::install_torch()`.")
298+
cli::cli_abort("The torch backend has not been installed; use `torch::install_torch()`.")
299299
}
300300

301301
f_nm <- "brulee_linear_reg"
@@ -325,7 +325,7 @@ brulee_linear_reg_bridge <- function(processed, epochs, optimizer,
325325
if (!is.matrix(predictors)) {
326326
predictors <- as.matrix(predictors)
327327
if (is.character(predictors)) {
328-
rlang::abort(
328+
cli::cli_abort(
329329
paste(
330330
"There were some non-numeric columns in the predictors.",
331331
"Please use a formula or recipe to encode all of the predictors as numeric."
@@ -371,22 +371,22 @@ brulee_linear_reg_bridge <- function(processed, epochs, optimizer,
371371
new_brulee_linear_reg <- function( model_obj, estimates, best_epoch, loss,
372372
dims, y_stats, parameters, blueprint) {
373373
if (!inherits(model_obj, "raw")) {
374-
rlang::abort("'model_obj' should be a raw vector.")
374+
cli::cli_abort("'model_obj' should be a raw vector.")
375375
}
376376
if (!is.list(estimates)) {
377-
rlang::abort("'parameters' should be a list")
377+
cli::cli_abort("'parameters' should be a list")
378378
}
379379
if (!is.vector(loss) || !is.numeric(loss)) {
380-
rlang::abort("'loss' should be a numeric vector")
380+
cli::cli_abort("'loss' should be a numeric vector")
381381
}
382382
if (!is.list(dims)) {
383-
rlang::abort("'dims' should be a list")
383+
cli::cli_abort("'dims' should be a list")
384384
}
385385
if (!is.list(parameters)) {
386-
rlang::abort("'parameters' should be a list")
386+
cli::cli_abort("'parameters' should be a list")
387387
}
388388
if (!inherits(blueprint, "hardhat_blueprint")) {
389-
rlang::abort("'blueprint' should be a hardhat blueprint")
389+
cli::cli_abort("'blueprint' should be a hardhat blueprint")
390390
}
391391
hardhat::new_model(model_obj = model_obj,
392392
estimates = estimates,
@@ -453,7 +453,7 @@ linear_reg_fit_imp <-
453453
loss_label <- "\tLoss (scaled):"
454454

455455
if (optimizer == "LBFGS" & !is.null(batch_size)) {
456-
rlang::warn("'batch_size' is only used for the SGD optimizer.")
456+
cli::cli_warn("'batch_size' is only used for the SGD optimizer.")
457457
batch_size <- NULL
458458
}
459459
if (is.null(batch_size)) {
@@ -524,7 +524,7 @@ linear_reg_fit_imp <-
524524
loss_vec[epoch] <- loss_curr
525525

526526
if (is.nan(loss_curr)) {
527-
rlang::warn("Current loss in NaN. Training wil be stopped.")
527+
cli::cli_warn("Current loss in NaN. Training wil be stopped.")
528528
break()
529529
}
530530

@@ -548,7 +548,7 @@ linear_reg_fit_imp <-
548548
msg <- paste("epoch:", epoch_chr[epoch], loss_label,
549549
signif(loss_curr, 3), loss_note)
550550

551-
rlang::inform(msg)
551+
cli::cli_inform(msg)
552552
}
553553

554554
if (poor_epoch == stop_iter) {

R/linear_reg-predict.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ predict_brulee_linear_reg_bridge <- function(type, model, predictors, epoch) {
5757
if (!is.matrix(predictors)) {
5858
predictors <- as.matrix(predictors)
5959
if (is.character(predictors)) {
60-
rlang::abort(
60+
cli::cli_abort(
6161
paste(
6262
"There were some non-numeric columns in the predictors.",
6363
"Please use a formula or recipe to encode all of the predictors as numeric."
@@ -72,7 +72,7 @@ predict_brulee_linear_reg_bridge <- function(type, model, predictors, epoch) {
7272
if (epoch > max_epoch) {
7373
msg <- paste("The model fit only", max_epoch, "epochs; predictions cannot",
7474
"be made at epoch", epoch, "so last epoch is used.")
75-
rlang::warn(msg)
75+
cli::cli_warn(msg)
7676
}
7777

7878
predictions <- predict_function(model, predictors, epoch)

R/logistic_reg-fit.R

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ brulee_logistic_reg_bridge <- function(processed, epochs, optimizer,
290290
learn_rate, momentum, penalty, mixture, class_weights,
291291
validation, batch_size, stop_iter, verbose, ...) {
292292
if(!torch::torch_is_installed()) {
293-
rlang::abort("The torch backend has not been installed; use `torch::install_torch()`.")
293+
cli::cli_abort("The torch backend has not been installed; use `torch::install_torch()`.")
294294
}
295295

296296
f_nm <- "brulee_logistic_reg"
@@ -319,7 +319,7 @@ brulee_logistic_reg_bridge <- function(processed, epochs, optimizer,
319319
if (!is.matrix(predictors)) {
320320
predictors <- as.matrix(predictors)
321321
if (is.character(predictors)) {
322-
rlang::abort(
322+
cli::cli_abort(
323323
paste(
324324
"There were some non-numeric columns in the predictors.",
325325
"Please use a formula or recipe to encode all of the predictors as numeric."
@@ -337,7 +337,7 @@ brulee_logistic_reg_bridge <- function(processed, epochs, optimizer,
337337

338338
outcome <- processed$outcomes[[1]]
339339
if (length(levels(outcome)) > 2) {
340-
rlang::abort("logistic regression is for outcomes with two classes.")
340+
cli::cli_abort("logistic regression is for outcomes with two classes.")
341341
}
342342

343343
# ------------------------------------------------------------------------------
@@ -380,22 +380,22 @@ brulee_logistic_reg_bridge <- function(processed, epochs, optimizer,
380380
new_brulee_logistic_reg <- function( model_obj, estimates, best_epoch, loss,
381381
dims, y_stats, parameters, blueprint) {
382382
if (!inherits(model_obj, "raw")) {
383-
rlang::abort("'model_obj' should be a raw vector.")
383+
cli::cli_abort("'model_obj' should be a raw vector.")
384384
}
385385
if (!is.list(estimates)) {
386-
rlang::abort("'parameters' should be a list")
386+
cli::cli_abort("'parameters' should be a list")
387387
}
388388
if (!is.vector(loss) || !is.numeric(loss)) {
389-
rlang::abort("'loss' should be a numeric vector")
389+
cli::cli_abort("'loss' should be a numeric vector")
390390
}
391391
if (!is.list(dims)) {
392-
rlang::abort("'dims' should be a list")
392+
cli::cli_abort("'dims' should be a list")
393393
}
394394
if (!is.list(parameters)) {
395-
rlang::abort("'parameters' should be a list")
395+
cli::cli_abort("'parameters' should be a list")
396396
}
397397
if (!inherits(blueprint, "hardhat_blueprint")) {
398-
rlang::abort("'blueprint' should be a hardhat blueprint")
398+
cli::cli_abort("'blueprint' should be a hardhat blueprint")
399399
}
400400
hardhat::new_model(model_obj = model_obj,
401401
estimates = estimates,
@@ -464,7 +464,7 @@ logistic_reg_fit_imp <-
464464
loss_label <- "\tLoss:"
465465

466466
if (optimizer == "LBFGS" & !is.null(batch_size)) {
467-
rlang::warn("'batch_size' is only used for the SGD optimizer.")
467+
cli::cli_warn("'batch_size' is only used for the SGD optimizer.")
468468
batch_size <- NULL
469469
}
470470
if (is.null(batch_size)) {
@@ -535,7 +535,7 @@ logistic_reg_fit_imp <-
535535
loss_vec[epoch] <- loss_curr
536536

537537
if (is.nan(loss_curr)) {
538-
rlang::warn("Current loss in NaN. Training wil be stopped.")
538+
cli::cli_warn("Current loss in NaN. Training wil be stopped.")
539539
break()
540540
}
541541

@@ -558,7 +558,7 @@ logistic_reg_fit_imp <-
558558
msg <- paste("epoch:", epoch_chr[epoch], loss_label,
559559
signif(loss_curr, 3), loss_note)
560560

561-
rlang::inform(msg)
561+
cli::cli_inform(msg)
562562
}
563563

564564
if (poor_epoch == stop_iter) {

R/logistic_reg-predict.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ predict_brulee_logistic_reg_bridge <- function(type, model, predictors, epoch) {
6464
if (!is.matrix(predictors)) {
6565
predictors <- as.matrix(predictors)
6666
if (is.character(predictors)) {
67-
rlang::abort(
67+
cli::cli_abort(
6868
paste(
6969
"There were some non-numeric columns in the predictors.",
7070
"Please use a formula or recipe to encode all of the predictors as numeric."
@@ -79,7 +79,7 @@ predict_brulee_logistic_reg_bridge <- function(type, model, predictors, epoch) {
7979
if (epoch > max_epoch) {
8080
msg <- paste("The model fit only", max_epoch, "epochs; predictions cannot",
8181
"be made at epoch", epoch, "so last epoch is used.")
82-
rlang::warn(msg)
82+
cli::cli_warn(msg)
8383
}
8484

8585
predictions <- predict_function(model, predictors, epoch)

0 commit comments

Comments
 (0)