Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ S3method(tunable,logistic_reg)
S3method(tunable,mars)
S3method(tunable,mlp)
S3method(tunable,model_spec)
S3method(tunable,multinomial_reg)
S3method(tunable,multinom_reg)
S3method(tunable,rand_forest)
S3method(tunable,survival_reg)
S3method(tunable,svm_poly)
Expand Down
11 changes: 11 additions & 0 deletions R/mlp_brulee_two_layer.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#' Multilayer perceptron via brulee with two hidden layers
#'
#' [brulee::brulee_mlp_two_layer()] fits a neural network (with version 0.3.0.9000 or higher of brulee)
#'
#' @includeRmd man/rmd/mlp_brulee_two_layer.md details
#'
#' @name details_mlp_brulee_two_layer
#' @keywords internal
NULL

# See inst/README-DOCS.md for a description of how these files are processed
166 changes: 165 additions & 1 deletion R/mlp_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ set_pred(

set_model_engine("mlp", "classification", "brulee")
set_model_engine("mlp", "regression", "brulee")
set_dependency("mlp", "brulee", "brulee")
set_dependency("mlp", "brulee", "brulee", mode = "classification")
set_dependency("mlp", "brulee", "brulee", mode = "regression")

set_model_arg(
model = "mlp",
Expand Down Expand Up @@ -527,3 +528,166 @@ set_pred(
)
)


set_model_engine("mlp", "classification", "brulee_two_layer")
set_model_engine("mlp", "regression", "brulee_two_layer")
set_dependency("mlp", "brulee_two_layer", "brulee", mode = "classification")
set_dependency("mlp", "brulee_two_layer", "brulee", mode = "regression")

set_model_arg(
model = "mlp",
eng = "brulee_two_layer",
parsnip = "hidden_units",
original = "hidden_units",
func = list(pkg = "dials", fun = "hidden_units"),
has_submodel = FALSE
)

set_model_arg(
model = "mlp",
eng = "brulee_two_layer",
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)

set_model_arg(
model = "mlp",
eng = "brulee_two_layer",
parsnip = "epochs",
original = "epochs",
func = list(pkg = "dials", fun = "epochs"),
has_submodel = FALSE
)

set_model_arg(
model = "mlp",
eng = "brulee_two_layer",
parsnip = "dropout",
original = "dropout",
func = list(pkg = "dials", fun = "dropout"),
has_submodel = FALSE
)

set_model_arg(
model = "mlp",
eng = "brulee_two_layer",
parsnip = "learn_rate",
original = "learn_rate",
func = list(pkg = "dials", fun = "learn_rate", range = c(-2.5, -0.5)),
has_submodel = FALSE
)

set_model_arg(
model = "mlp",
eng = "brulee_two_layer",
parsnip = "activation",
original = "activation",
func = list(pkg = "dials", fun = "activation", values = c('relu', 'elu', 'tanh')),
has_submodel = FALSE
)


set_fit(
model = "mlp",
eng = "brulee_two_layer",
mode = "regression",
value = list(
interface = "data.frame",
protect = c("x", "y"),
func = c(pkg = "brulee", fun = "brulee_mlp_two_layer"),
defaults = list()
)
)

set_encoding(
model = "mlp",
eng = "brulee_two_layer",
mode = "regression",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

set_fit(
model = "mlp",
eng = "brulee_two_layer",
mode = "classification",
value = list(
interface = "data.frame",
protect = c("x", "y"),
func = c(pkg = "brulee", fun = "brulee_mlp_two_layer"),
defaults = list()
)
)

set_encoding(
model = "mlp",
eng = "brulee_two_layer",
mode = "classification",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

set_pred(
model = "mlp",
eng = "brulee_two_layer",
mode = "regression",
type = "numeric",
value = list(
pre = NULL,
post = reformat_torch_num,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
type = "numeric"
)
)
)

set_pred(
model = "mlp",
eng = "brulee_two_layer",
mode = "classification",
type = "class",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
type = "class"
)
)
)

set_pred(
model = "mlp",
eng = "brulee_two_layer",
mode = "classification",
type = "prob",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
type = "prob"
)
)
)

128 changes: 76 additions & 52 deletions R/tunable.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,37 +194,6 @@ earth_engine_args <-
component_id = "engine"
)

brulee_mlp_engine_args <-
tibble::tribble(
~name, ~call_info,
"momentum", list(pkg = "dials", fun = "momentum", range = c(0.5, 0.95)),
"batch_size", list(pkg = "dials", fun = "batch_size", range = c(3, 10)),
"stop_iter", list(pkg = "dials", fun = "stop_iter"),
"class_weights", list(pkg = "dials", fun = "class_weights"),
"decay", list(pkg = "dials", fun = "rate_decay"),
"initial", list(pkg = "dials", fun = "rate_initial"),
"largest", list(pkg = "dials", fun = "rate_largest"),
"rate_schedule", list(pkg = "dials", fun = "rate_schedule"),
"step_size", list(pkg = "dials", fun = "rate_step_size"),
"mixture", list(pkg = "dials", fun = "mixture")
) %>%
dplyr::mutate(source = "model_spec",
component = "mlp",
component_id = "engine"
)

brulee_linear_engine_args <-
brulee_mlp_engine_args %>%
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter"))

brulee_logistic_engine_args <-
brulee_mlp_engine_args %>%
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))

brulee_multinomial_engine_args <-
brulee_mlp_engine_args %>%
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))

flexsurvspline_engine_args <-
tibble::tibble(
name = c("k"),
Expand All @@ -236,6 +205,42 @@ flexsurvspline_engine_args <-
component_id = "engine"
)

# ------------------------------------------------------------------------------
# used for brulee engines:

tune_activations <- c("relu", "tanh", "elu", "log_sigmoid", "tanhshrink")
tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")

brulee_mlp_args <-
tibble::tibble(
name = c('epochs', 'hidden_units', 'hidden_units_2', 'activation', 'activation_2',
'penalty', 'mixture', 'dropout', 'learn_rate', 'momentum', 'batch_size',
'class_weights', 'stop_iter', 'rate_schedule'),
call_info = list(
list(pkg = "dials", fun = "epochs", range = c(5L, 500L)),
list(pkg = "dials", fun = "hidden_units", range = c(2L, 50L)),
list(pkg = "dials", fun = "hidden_units_2", range = c(2L, 50L)),
list(pkg = "dials", fun = "activation", values = tune_activations),
list(pkg = "dials", fun = "activation_2", values = tune_activations),
list(pkg = "dials", fun = "penalty"),
list(pkg = "dials", fun = "mixture"),
list(pkg = "dials", fun = "dropout"),
list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/5)),
list(pkg = "dials", fun = "momentum", range = c(0.50, 0.95)),
list(pkg = "dials", fun = "batch_size"),
list(pkg = "dials", fun = "stop_iter"),
list(pkg = "dials", fun = "class_weights"),
list(pkg = "dials", fun = "rate_schedule", values = tune_sched)
)
) %>%
dplyr::mutate(source = "model_spec")

brulee_mlp_only_args <-
tibble::tibble(
name =
c('hidden_units', 'hidden_units_2', 'activation', 'activation_2', 'dropout')
)

# ------------------------------------------------------------------------------

#' @export
Expand All @@ -245,31 +250,55 @@ tunable.linear_reg <- function(x, ...) {
res$call_info[res$name == "mixture"] <-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
} else if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_linear_engine_args)
res <-
brulee_mlp_args %>%
dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
dplyr::filter(name != "class_weights") %>%
dplyr::mutate(
component = "linear_reg",
component_id = ifelse(name %in% names(formals("linear_reg")), "main", "engine")
) %>%
dplyr::select(name, call_info, source, component, component_id)
}
res
}

#' @export

#' @export
tunable.logistic_reg <- function(x, ...) {
res <- NextMethod()
if (x$engine == "glmnet") {
res$call_info[res$name == "mixture"] <-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
} else if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_logistic_engine_args)
res <-
brulee_mlp_args %>%
dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
dplyr::mutate(
component = "logistic_reg",
component_id = ifelse(name %in% names(formals("logistic_reg")), "main", "engine")
) %>%
dplyr::select(name, call_info, source, component, component_id)
}
res
}

#' @export
tunable.multinomial_reg <- function(x, ...) {
tunable.multinom_reg <- function(x, ...) {
res <- NextMethod()
if (x$engine == "glmnet") {
res$call_info[res$name == "mixture"] <-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
} else if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_multinomial_engine_args)
res <-
brulee_mlp_args %>%
dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
dplyr::mutate(
component = "multinom_reg",
component_id = ifelse(name %in% names(formals("multinom_reg")), "main", "engine")
) %>%
dplyr::select(name, call_info, source, component, component_id)
}
res
}
Expand Down Expand Up @@ -345,28 +374,23 @@ tunable.svm_poly <- function(x, ...) {
res
}


#' @export
tunable.mlp <- function(x, ...) {
res <- NextMethod()
if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_mlp_engine_args)
res$call_info[res$name == "learn_rate"] <-
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
res$call_info[res$name == "epochs"] <-
list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
activation_values <- rlang::eval_tidy(
rlang::call2("brulee_activations", .ns = "brulee")
)
res$call_info[res$name == "activation"] <-
list(list(pkg = "dials", fun = "activation", values = activation_values))
} else if (x$engine == "keras") {
activation_values <- parsnip::keras_activations()
res$call_info[res$name == "activation"] <-
list(list(pkg = "dials", fun = "activation", values = activation_values))
if (grepl("brulee", x$engine)) {
res <-
brulee_mlp_args %>%
dplyr::mutate(
component = "mlp",
component_id = ifelse(name %in% names(formals("mlp")), "main", "engine")
) %>%
dplyr::select(name, call_info, source, component, component_id)
if (x$engine == "brulee") {
res <- res[!grepl("_2", res$name),]
}
}
res
}
}

#' @export
tunable.survival_reg <- function(x, ...) {
Expand Down
Loading
Loading