Skip to content

Commit b271b7b

Browse files
authored
Make exported RNG functions respect changes to R's seed (#973)
* Make exported RNG functions respect changes to R's seed * Simpler seed setting * Seed set location * Fix test
1 parent c218b0d commit b271b7b

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

R/utils.R

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -887,12 +887,16 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
887887
}
888888
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
889889
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
890-
if (cmdstan_version() < "2.35.0") {
891-
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
892-
} else {
893-
fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
890+
if (grepl("(stan::rng_t|boost::ecuyer1988)", fun_body)) {
891+
if (cmdstan_version() < "2.35.0") {
892+
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr, SEXP seed", fun_body)
893+
} else {
894+
fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr, SEXP seed", fun_body)
895+
}
896+
rng_seed <- "Rcpp::XPtr<stan::rng_t> base_rng(base_rng_ptr);base_rng->seed(Rcpp::as<int>(seed));"
897+
fun_body <- gsub("return", paste(rng_seed, "return"), fun_body)
898+
fun_body <- gsub("base_rng__,", "*(base_rng.get()),", fun_body, fixed = TRUE)
894899
}
895-
fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr<stan::rng_t>(base_rng_ptr).get()),", fun_body, fixed = TRUE)
896900
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
897901
fun_body <- paste(fun_body, collapse = "\n")
898902
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
@@ -953,6 +957,9 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
953957
fundef <- get(fun, envir = fun_env)
954958
funargs <- formals(fundef)
955959
funargs$base_rng_ptr <- env$rng_ptr
960+
# To allow for exported RNG functions to respect the R 'set.seed()' call,
961+
# we need to derive a seed deterministically from the current RNG state
962+
funargs$seed <- quote(sample.int(.Machine$integer.max, 1))
956963
formals(fundef) <- funargs
957964
assign(fun, fundef, envir = fun_env)
958965
}

tests/testthat/test-model-expose-functions.R

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,16 @@ test_that("rng functions can be exposed", {
346346
mod <- cmdstan_model(model, force_recompile = TRUE)
347347
fit <- mod$sample(data = data_list)
348348

349-
set.seed(10)
350349
fit$expose_functions(verbose = TRUE)
350+
set.seed(10)
351+
res1_1 <- fit$functions$wrap_normal_rng(5,10)
352+
res2_1 <- fit$functions$wrap_normal_rng(5,10)
353+
set.seed(10)
354+
res1_2 <- fit$functions$wrap_normal_rng(5,10)
355+
res2_2 <- fit$functions$wrap_normal_rng(5,10)
351356

352-
expect_equal(
353-
fit$functions$wrap_normal_rng(5,10),
354-
# Stan RNG changed in 2.35
355-
ifelse(cmdstan_version() < "2.35.0",-4.529876423, 0.02974925)
356-
)
357-
358-
expect_equal(
359-
fit$functions$wrap_normal_rng(5,10),
360-
ifelse(cmdstan_version() < "2.35.0", 8.12959026, 10.3881349)
361-
)
357+
expect_equal(res1_1, res1_2)
358+
expect_equal(res2_1, res2_2)
362359
})
363360

364361
test_that("Overloaded functions give meaningful errors", {

0 commit comments

Comments
 (0)