Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ CmdStanModel <- R6::R6Class(
precompile_cpp_options_ = NULL,
precompile_stanc_options_ = NULL,
precompile_include_paths_ = NULL,
variables_ = NULL
variables_ = NULL,
cmdstan_version_ = NULL
),
public = list(
functions = NULL,
Expand Down Expand Up @@ -271,8 +272,14 @@ CmdStanModel <- R6::R6Class(
if (!is.null(stan_file) && compile) {
self$compile(...)
}

# for now, set this based on current version
# at initialize so its never null
# in the future, will be set only if/when we have a binary
# as the version the model was compiled with
private$cmdstan_version_ <- cmdstan_version()
if (length(self$exe_file()) > 0 && file.exists(self$exe_file())) {
cpp_options <- model_compile_info(self$exe_file())
cpp_options <- model_compile_info(self$exe_file(), self$cmdstan_version())
for (cpp_option_name in names(cpp_options)) {
if (cpp_option_name != "stan_version" &&
(!is.logical(cpp_options[[cpp_option_name]]) || isTRUE(cpp_options[[cpp_option_name]]))) {
Expand Down Expand Up @@ -328,6 +335,9 @@ CmdStanModel <- R6::R6Class(
}
private$exe_file_
},
cmdstan_version = function() {
private$cmdstan_version_
},
cpp_options = function() {
private$cpp_options_
},
Expand Down Expand Up @@ -737,6 +747,7 @@ compile <- function(quiet = TRUE,
con = wsl_safe_path(private$hpp_file_, revert = TRUE))
} # End - if(!dry_run)

private$cmdstan_version_ <- cmdstan_version()
private$exe_file_ <- exe
private$cpp_options_ <- cpp_options
private$precompile_cpp_options_ <- NULL
Expand Down Expand Up @@ -786,7 +797,7 @@ CmdStanModel$set("public", name = "compile", value = compile)
#' }
#'
variables <- function() {
if (cmdstan_version() < "2.27.0") {
if (self$cmdstan_version() < "2.27.0") {
stop("$variables() is only supported for CmdStan 2.27 or newer.", call. = FALSE)
}
if (length(self$stan_file()) == 0) {
Expand Down Expand Up @@ -993,6 +1004,10 @@ format <- function(overwrite_file = FALSE,
backup = TRUE,
max_line_length = NULL,
quiet = FALSE) {
# querying current version here not model object version
# because this is pre-compile work based on the cmdstanr
# version that will be used to compile in teh future,
# not based on what was used to compile existing binary (if any)
if (cmdstan_version() < "2.29.0" && !is.null(max_line_length)) {
stop(
"'max_line_length' is only supported with CmdStan 2.29.0 or newer.",
Expand Down Expand Up @@ -1208,7 +1223,7 @@ sample <- function(data = NULL,
}
}

if (cmdstan_version() >= "2.27.0" && cmdstan_version() < "2.36.0" && !fixed_param) {
if (self$cmdstan_version() >= "2.27.0" && self$cmdstan_version() < "2.36.0" && !fixed_param) {
if (self$has_stan_file() && file.exists(self$stan_file())) {
if (!is.null(self$variables()) && length(self$variables()$parameters) == 0) {
stop("Model contains no parameters. Please use 'fixed_param = TRUE'.", call. = FALSE)
Expand Down Expand Up @@ -1652,7 +1667,7 @@ laplace <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
save_cmdstan_config = NULL) {
if (cmdstan_version() < "2.32") {
if (self$cmdstan_version() < "2.32") {
stop("This method is only available in cmdstan >= 2.32", call. = FALSE)
}
if (!is.null(mode) && !is.null(opt_args)) {
Expand Down Expand Up @@ -2382,9 +2397,9 @@ model_variables <- function(stan_file, include_paths = NULL, allow_undefined = F
variables
}

model_compile_info <- function(exe_file) {
model_compile_info <- function(exe_file, version) {
info <- NULL
if (cmdstan_version() > "2.26.1") {
if (version > "2.26.1") {
withr::with_path(
c(
toolchain_PATH_env_var(),
Expand Down
14 changes: 11 additions & 3 deletions R/path.R
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,19 @@ unset_cmdstan_path <- function() {
}

# fake a cmdstan version (only used in tests)
fake_cmdstan_version <- function(version) {
fake_cmdstan_version <- function(version, mod = NULL) {
.cmdstanr$VERSION <- version
if (!is.null(mod)) {
if (!is.null(mod$.__enclos_env__$private$exe_info_)) {
mod$.__enclos_env__$private$exe_info_$stan_version <- version
}
if (!is.null(mod$.__enclos_env__$private$cmdstan_version_)) {
mod$.__enclos_env__$private$cmdstan_version_ <- version
}
}
}
reset_cmdstan_version <- function() {
.cmdstanr$VERSION <- read_cmdstan_version(cmdstan_path())
reset_cmdstan_version <- function(mod = NULL) {
fake_cmdstan_version(read_cmdstan_version(cmdstan_path()), mod = mod)
}

.home_path <- function() {
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/test-model-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,13 @@ test_that("Correct behavior if fixed_param not set when the model has no paramet
"
stan_file <- write_stan_file(code)
m <- cmdstan_model(stan_file)
fake_cmdstan_version("2.35.0")
fake_cmdstan_version("2.35.0", m)
expect_error(
m$sample(),
"Model contains no parameters. Please use 'fixed_param = TRUE'."
)

reset_cmdstan_version()
reset_cmdstan_version(m)
if (cmdstan_version() >= "2.36.0") {
# as of 2.36.0 we don't need fixed_param if no parameters
expect_no_error(
Expand All @@ -334,13 +334,13 @@ test_that("Correct behavior if fixed_param not set when the model has no paramet
})

test_that("sig_figs warning if version less than 2.25", {
fake_cmdstan_version("2.24.0")
fake_cmdstan_version("2.24.0", mod)
expect_warning(
expect_sample_output(mod$sample(data = data_list, chains = 1, refresh = 0, sig_figs = 3)),
"The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!",
fixed = TRUE
)
reset_cmdstan_version()
reset_cmdstan_version(mod)
})

test_that("Errors are suppressed with show_exceptions", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-model-variables.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set_cmdstan_path()
test_that("$variables() errors if version less than 2.27", {
mod <- testing_model("bernoulli")
ver <- cmdstan_version()
.cmdstanr$VERSION <- "2.26.0"
fake_cmdstan_version("2.26.0", mod = mod)
expect_error(
mod$variables(),
"$variables() is only supported for CmdStan 2.27 or newer",
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-opencl.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ test_that("all methods run with valid opencl_ids", {

test_that("error for runtime selection of OpenCL devices if version less than 2.26", {
skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true"))
fake_cmdstan_version("2.25.0")

stan_file <- testing_stan_file("bernoulli")
mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE),
force_recompile = TRUE)
fake_cmdstan_version("2.25.0", mod)
expect_error(
mod$sample(data = testing_data("bernoulli"), chains = 1, refresh = 0, opencl_ids = c(1,1)),
"Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer",
fixed = TRUE
)
reset_cmdstan_version()
reset_cmdstan_version(mod)
})
Loading