diff --git a/R/fit.R b/R/fit.R index 7f80f6d1b..d355dcb74 100644 --- a/R/fit.R +++ b/R/fit.R @@ -64,6 +64,10 @@ CmdStanFit <- R6::R6Class( "rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)\n") } invisible(self) + }, + expose_functions = function(global = FALSE, verbose = FALSE) { + expose_functions(self$functions, global, verbose) + invisible(NULL) } ), private = list( @@ -280,30 +284,6 @@ init <- function() { } CmdStanFit$set("public", name = "init", value = init) -expose_functions <- function(global = FALSE, verbose = FALSE) { - require_suggested_package("Rcpp") - require_suggested_package("RcppEigen") - require_suggested_package("decor") - if (self$functions$compiled) { - if (!global) { - message("Functions already compiled, nothing to do!") - } else { - message("Functions already compiled, copying to global environment") - # Create reference to global environment, avoids NOTE about assigning to global - pos <- 1 - envir = as.environment(pos) - lapply(self$functions$fun_names, function(fun_name) { - assign(fun_name, get(fun_name, self$functions), envir) - }) - } - } else { - message("Compiling standalone functions...") - expose_functions(self$functions, verbose, global) - } - invisible(NULL) -} -CmdStanFit$set("public", name = "expose_functions", value = expose_functions) - #' Compile additional methods for accessing the model log-probability function #' and parameter constraining and unconstraining. This requires the `Rcpp` package. #' diff --git a/R/model.R b/R/model.R index 4d90b390a..a1bfddc64 100644 --- a/R/model.R +++ b/R/model.R @@ -219,13 +219,15 @@ CmdStanModel <- R6::R6Class( precompile_cpp_options_ = NULL, precompile_stanc_options_ = NULL, precompile_include_paths_ = NULL, - variables_ = NULL, - standalone_env_ = NULL + variables_ = NULL ), public = list( + functions = NULL, initialize = function(stan_file = NULL, exe_file = NULL, compile, ...) { args <- list(...) private$dir_ <- args$dir + self$functions <- new.env() + self$functions$compiled <- FALSE if (!is.null(stan_file)) { checkmate::assert_file_exists(stan_file, access = "r", extension = "stan") checkmate::assert_flag(compile) @@ -323,6 +325,10 @@ CmdStanModel <- R6::R6Class( "- ", new_hpp_loc) private$hpp_file_ <- new_hpp_loc invisible(private$hpp_file_) + }, + expose_functions = function(global = FALSE, verbose = FALSE) { + expose_functions(self$functions, global, verbose) + invisible(NULL) } ) ) @@ -561,11 +567,9 @@ compile <- function(quiet = TRUE, } } stancflags_standalone <- c("--standalone-functions", stancflags_val, stanc_built_options) - private$standalone_env_ <- new.env() - private$standalone_env_$compiled <- FALSE - private$standalone_env_$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone) + self$functions$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone) if (compile_standalone) { - expose_functions(private$standalone_env_, !quiet) + expose_functions(self$functions, !quiet) } stancflags_val <- paste0("STANCFLAGS += ", stancflags_val, paste0(" ", stanc_built_options, collapse = " ")) withr::with_path( @@ -1127,7 +1131,7 @@ sample <- function(data = NULL, stan_file = self$stan_file(), stan_code = suppressWarnings(self$code()), model_methods_env = private$model_methods_env_, - standalone_env = private$standalone_env_, + standalone_env = self$functions, model_name = self$model_name(), exe_file = self$exe_file(), proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE), @@ -1286,7 +1290,7 @@ sample_mpi <- function(data = NULL, stan_file = self$stan_file(), stan_code = suppressWarnings(self$code()), model_methods_env = private$model_methods_env_, - standalone_env = private$standalone_env_, + standalone_env = self$functions, model_name = self$model_name(), exe_file = self$exe_file(), proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE), @@ -1399,7 +1403,7 @@ optimize <- function(data = NULL, stan_file = self$stan_file(), stan_code = suppressWarnings(self$code()), model_methods_env = private$model_methods_env_, - standalone_env = private$standalone_env_, + standalone_env = self$functions, model_name = self$model_name(), exe_file = self$exe_file(), proc_ids = 1, @@ -1518,7 +1522,7 @@ variational <- function(data = NULL, stan_file = self$stan_file(), stan_code = suppressWarnings(self$code()), model_methods_env = private$model_methods_env_, - standalone_env = private$standalone_env_, + standalone_env = self$functions, model_name = self$model_name(), exe_file = self$exe_file(), proc_ids = 1, @@ -1636,7 +1640,7 @@ generate_quantities <- function(fitted_params, stan_file = self$stan_file(), stan_code = suppressWarnings(self$code()), model_methods_env = private$model_methods_env_, - standalone_env = private$standalone_env_, + standalone_env = self$functions, model_name = self$model_name(), exe_file = self$exe_file(), proc_ids = seq_along(fitted_params_files), @@ -1701,7 +1705,7 @@ diagnose <- function(data = NULL, stan_file = self$stan_file(), stan_code = suppressWarnings(self$code()), model_methods_env = private$model_methods_env_, - standalone_env = private$standalone_env_, + standalone_env = self$functions, model_name = self$model_name(), exe_file = self$exe_file(), proc_ids = 1, diff --git a/R/utils.R b/R/utils.R index c02e9fe17..c1b3c09a0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -675,7 +675,7 @@ prep_fun_cpp <- function(fun_body, model_lines) { gsub(pattern = ",\\s*)", replacement = ")", fun_body) } -expose_functions <- function(env, verbose = FALSE, global = FALSE) { +compile_functions <- function(env, verbose = FALSE, global = FALSE) { funs <- grep("// [[stan::function]]", env$hpp_code, fixed = TRUE) funs <- c(funs, length(env$hpp_code)) @@ -702,3 +702,26 @@ expose_functions <- function(env, verbose = FALSE, global = FALSE) { env$compiled <- TRUE invisible(NULL) } + +expose_functions <- function(function_env, global = FALSE, verbose = FALSE) { + require_suggested_package("Rcpp") + require_suggested_package("RcppEigen") + require_suggested_package("decor") + if (function_env$compiled) { + if (!global) { + message("Functions already compiled, nothing to do!") + } else { + message("Functions already compiled, copying to global environment") + # Create reference to global environment, avoids NOTE about assigning to global + pos <- 1 + envir = as.environment(pos) + lapply(function_env$fun_names, function(fun_name) { + assign(fun_name, get(fun_name, function_env), envir) + }) + } + } else { + message("Compiling standalone functions...") + compile_functions(function_env, verbose, global) + } + invisible(NULL) +} \ No newline at end of file diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index 6a27bf133..d1e5a820f 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -12,6 +12,23 @@ data_list <- testing_data("bernoulli") mod <- cmdstan_model(model, force_recompile = TRUE) fit <- mod$sample(data = data_list) + +test_that("Functions can be exposed in model object", { + mod$expose_functions(verbose = TRUE) + + expect_equal( + fit$functions$retvec(c(1,2,3,4)), + c(1,2,3,4) + ) + + mod$expose_functions(global = TRUE, verbose = TRUE) + + expect_equal( + retvec(c(1,2,3,4)), + c(1,2,3,4) + ) +}) + test_that("Functions can be exposed in fit object", { fit$expose_functions(verbose = TRUE)