diff --git a/R/model.R b/R/model.R index 323320b46..2a29560a9 100644 --- a/R/model.R +++ b/R/model.R @@ -233,7 +233,8 @@ CmdStanModel <- R6::R6Class( precompile_cpp_options_ = NULL, precompile_stanc_options_ = NULL, precompile_include_paths_ = NULL, - variables_ = NULL + variables_ = NULL, + cmdstan_version_ = NULL ), public = list( functions = NULL, @@ -271,8 +272,14 @@ CmdStanModel <- R6::R6Class( if (!is.null(stan_file) && compile) { self$compile(...) } + + # for now, set this based on current version + # at initialize so its never null + # in the future, will be set only if/when we have a binary + # as the version the model was compiled with + private$cmdstan_version_ <- cmdstan_version() if (length(self$exe_file()) > 0 && file.exists(self$exe_file())) { - cpp_options <- model_compile_info(self$exe_file()) + cpp_options <- model_compile_info(self$exe_file(), self$cmdstan_version()) for (cpp_option_name in names(cpp_options)) { if (cpp_option_name != "stan_version" && (!is.logical(cpp_options[[cpp_option_name]]) || isTRUE(cpp_options[[cpp_option_name]]))) { @@ -328,6 +335,9 @@ CmdStanModel <- R6::R6Class( } private$exe_file_ }, + cmdstan_version = function() { + private$cmdstan_version_ + }, cpp_options = function() { private$cpp_options_ }, @@ -737,6 +747,7 @@ compile <- function(quiet = TRUE, con = wsl_safe_path(private$hpp_file_, revert = TRUE)) } # End - if(!dry_run) + private$cmdstan_version_ <- cmdstan_version() private$exe_file_ <- exe private$cpp_options_ <- cpp_options private$precompile_cpp_options_ <- NULL @@ -786,7 +797,7 @@ CmdStanModel$set("public", name = "compile", value = compile) #' } #' variables <- function() { - if (cmdstan_version() < "2.27.0") { + if (self$cmdstan_version() < "2.27.0") { stop("$variables() is only supported for CmdStan 2.27 or newer.", call. = FALSE) } if (length(self$stan_file()) == 0) { @@ -993,6 +1004,10 @@ format <- function(overwrite_file = FALSE, backup = TRUE, max_line_length = NULL, quiet = FALSE) { + # querying current version here not model object version + # because this is pre-compile work based on the cmdstanr + # version that will be used to compile in teh future, + # not based on what was used to compile existing binary (if any) if (cmdstan_version() < "2.29.0" && !is.null(max_line_length)) { stop( "'max_line_length' is only supported with CmdStan 2.29.0 or newer.", @@ -1208,7 +1223,7 @@ sample <- function(data = NULL, } } - if (cmdstan_version() >= "2.27.0" && cmdstan_version() < "2.36.0" && !fixed_param) { + if (self$cmdstan_version() >= "2.27.0" && self$cmdstan_version() < "2.36.0" && !fixed_param) { if (self$has_stan_file() && file.exists(self$stan_file())) { if (!is.null(self$variables()) && length(self$variables()$parameters) == 0) { stop("Model contains no parameters. Please use 'fixed_param = TRUE'.", call. = FALSE) @@ -1652,7 +1667,7 @@ laplace <- function(data = NULL, show_messages = TRUE, show_exceptions = TRUE, save_cmdstan_config = NULL) { - if (cmdstan_version() < "2.32") { + if (self$cmdstan_version() < "2.32") { stop("This method is only available in cmdstan >= 2.32", call. = FALSE) } if (!is.null(mode) && !is.null(opt_args)) { @@ -2382,9 +2397,9 @@ model_variables <- function(stan_file, include_paths = NULL, allow_undefined = F variables } -model_compile_info <- function(exe_file) { +model_compile_info <- function(exe_file, version) { info <- NULL - if (cmdstan_version() > "2.26.1") { + if (version > "2.26.1") { withr::with_path( c( toolchain_PATH_env_var(), diff --git a/R/path.R b/R/path.R index 15bbeae69..4353d7b12 100644 --- a/R/path.R +++ b/R/path.R @@ -234,11 +234,19 @@ unset_cmdstan_path <- function() { } # fake a cmdstan version (only used in tests) -fake_cmdstan_version <- function(version) { +fake_cmdstan_version <- function(version, mod = NULL) { .cmdstanr$VERSION <- version + if (!is.null(mod)) { + if (!is.null(mod$.__enclos_env__$private$exe_info_)) { + mod$.__enclos_env__$private$exe_info_$stan_version <- version + } + if (!is.null(mod$.__enclos_env__$private$cmdstan_version_)) { + mod$.__enclos_env__$private$cmdstan_version_ <- version + } + } } -reset_cmdstan_version <- function() { - .cmdstanr$VERSION <- read_cmdstan_version(cmdstan_path()) +reset_cmdstan_version <- function(mod = NULL) { + fake_cmdstan_version(read_cmdstan_version(cmdstan_path()), mod = mod) } .home_path <- function() { diff --git a/tests/testthat/test-model-sample.R b/tests/testthat/test-model-sample.R index 5c9632474..15ec118f5 100644 --- a/tests/testthat/test-model-sample.R +++ b/tests/testthat/test-model-sample.R @@ -316,13 +316,13 @@ test_that("Correct behavior if fixed_param not set when the model has no paramet " stan_file <- write_stan_file(code) m <- cmdstan_model(stan_file) - fake_cmdstan_version("2.35.0") + fake_cmdstan_version("2.35.0", m) expect_error( m$sample(), "Model contains no parameters. Please use 'fixed_param = TRUE'." ) - reset_cmdstan_version() + reset_cmdstan_version(m) if (cmdstan_version() >= "2.36.0") { # as of 2.36.0 we don't need fixed_param if no parameters expect_no_error( @@ -334,13 +334,13 @@ test_that("Correct behavior if fixed_param not set when the model has no paramet }) test_that("sig_figs warning if version less than 2.25", { - fake_cmdstan_version("2.24.0") + fake_cmdstan_version("2.24.0", mod) expect_warning( expect_sample_output(mod$sample(data = data_list, chains = 1, refresh = 0, sig_figs = 3)), "The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!", fixed = TRUE ) - reset_cmdstan_version() + reset_cmdstan_version(mod) }) test_that("Errors are suppressed with show_exceptions", { diff --git a/tests/testthat/test-model-variables.R b/tests/testthat/test-model-variables.R index a34b72dd1..8545e4456 100644 --- a/tests/testthat/test-model-variables.R +++ b/tests/testthat/test-model-variables.R @@ -5,7 +5,7 @@ set_cmdstan_path() test_that("$variables() errors if version less than 2.27", { mod <- testing_model("bernoulli") ver <- cmdstan_version() - .cmdstanr$VERSION <- "2.26.0" + fake_cmdstan_version("2.26.0", mod = mod) expect_error( mod$variables(), "$variables() is only supported for CmdStan 2.27 or newer", diff --git a/tests/testthat/test-opencl.R b/tests/testthat/test-opencl.R index d5774d998..265f638fa 100644 --- a/tests/testthat/test-opencl.R +++ b/tests/testthat/test-opencl.R @@ -107,15 +107,15 @@ test_that("all methods run with valid opencl_ids", { test_that("error for runtime selection of OpenCL devices if version less than 2.26", { skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true")) - fake_cmdstan_version("2.25.0") stan_file <- testing_stan_file("bernoulli") mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE), force_recompile = TRUE) + fake_cmdstan_version("2.25.0", mod) expect_error( mod$sample(data = testing_data("bernoulli"), chains = 1, refresh = 0, opencl_ids = c(1,1)), "Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer", fixed = TRUE ) - reset_cmdstan_version() + reset_cmdstan_version(mod) })