Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow exposing functions without sampling #705

Merged
merged 1 commit into from
Oct 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 4 additions & 24 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ CmdStanFit <- R6::R6Class(
"rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)\n")
}
invisible(self)
},
expose_functions = function(global = FALSE, verbose = FALSE) {
expose_functions(self$functions, global, verbose)
invisible(NULL)
}
),
private = list(
Expand Down Expand Up @@ -280,30 +284,6 @@ init <- function() {
}
CmdStanFit$set("public", name = "init", value = init)

expose_functions <- function(global = FALSE, verbose = FALSE) {
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
require_suggested_package("decor")
if (self$functions$compiled) {
if (!global) {
message("Functions already compiled, nothing to do!")
} else {
message("Functions already compiled, copying to global environment")
# Create reference to global environment, avoids NOTE about assigning to global
pos <- 1
envir = as.environment(pos)
lapply(self$functions$fun_names, function(fun_name) {
assign(fun_name, get(fun_name, self$functions), envir)
})
}
} else {
message("Compiling standalone functions...")
expose_functions(self$functions, verbose, global)
}
invisible(NULL)
}
CmdStanFit$set("public", name = "expose_functions", value = expose_functions)

#' Compile additional methods for accessing the model log-probability function
#' and parameter constraining and unconstraining. This requires the `Rcpp` package.
#'
Expand Down
28 changes: 16 additions & 12 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,15 @@ CmdStanModel <- R6::R6Class(
precompile_cpp_options_ = NULL,
precompile_stanc_options_ = NULL,
precompile_include_paths_ = NULL,
variables_ = NULL,
standalone_env_ = NULL
variables_ = NULL
),
public = list(
functions = NULL,
initialize = function(stan_file = NULL, exe_file = NULL, compile, ...) {
args <- list(...)
private$dir_ <- args$dir
self$functions <- new.env()
self$functions$compiled <- FALSE
if (!is.null(stan_file)) {
checkmate::assert_file_exists(stan_file, access = "r", extension = "stan")
checkmate::assert_flag(compile)
Expand Down Expand Up @@ -323,6 +325,10 @@ CmdStanModel <- R6::R6Class(
"- ", new_hpp_loc)
private$hpp_file_ <- new_hpp_loc
invisible(private$hpp_file_)
},
expose_functions = function(global = FALSE, verbose = FALSE) {
expose_functions(self$functions, global, verbose)
invisible(NULL)
}
)
)
Expand Down Expand Up @@ -561,11 +567,9 @@ compile <- function(quiet = TRUE,
}
}
stancflags_standalone <- c("--standalone-functions", stancflags_val, stanc_built_options)
private$standalone_env_ <- new.env()
private$standalone_env_$compiled <- FALSE
private$standalone_env_$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone)
self$functions$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone)
if (compile_standalone) {
expose_functions(private$standalone_env_, !quiet)
expose_functions(self$functions, !quiet)
}
stancflags_val <- paste0("STANCFLAGS += ", stancflags_val, paste0(" ", stanc_built_options, collapse = " "))
withr::with_path(
Expand Down Expand Up @@ -1127,7 +1131,7 @@ sample <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1286,7 +1290,7 @@ sample_mpi <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1399,7 +1403,7 @@ optimize <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1518,7 +1522,7 @@ variational <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1636,7 +1640,7 @@ generate_quantities <- function(fitted_params,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = seq_along(fitted_params_files),
Expand Down Expand Up @@ -1701,7 +1705,7 @@ diagnose <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down
25 changes: 24 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ prep_fun_cpp <- function(fun_body, model_lines) {
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
}

expose_functions <- function(env, verbose = FALSE, global = FALSE) {
compile_functions <- function(env, verbose = FALSE, global = FALSE) {
funs <- grep("// [[stan::function]]", env$hpp_code, fixed = TRUE)
funs <- c(funs, length(env$hpp_code))

Expand All @@ -702,3 +702,26 @@ expose_functions <- function(env, verbose = FALSE, global = FALSE) {
env$compiled <- TRUE
invisible(NULL)
}

expose_functions <- function(function_env, global = FALSE, verbose = FALSE) {
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
require_suggested_package("decor")
if (function_env$compiled) {
if (!global) {
message("Functions already compiled, nothing to do!")
} else {
message("Functions already compiled, copying to global environment")
# Create reference to global environment, avoids NOTE about assigning to global
pos <- 1
envir = as.environment(pos)
lapply(function_env$fun_names, function(fun_name) {
assign(fun_name, get(fun_name, function_env), envir)
})
}
} else {
message("Compiling standalone functions...")
compile_functions(function_env, verbose, global)
}
invisible(NULL)
}
17 changes: 17 additions & 0 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@ data_list <- testing_data("bernoulli")
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)


test_that("Functions can be exposed in model object", {
mod$expose_functions(verbose = TRUE)

expect_equal(
fit$functions$retvec(c(1,2,3,4)),
c(1,2,3,4)
)

mod$expose_functions(global = TRUE, verbose = TRUE)

expect_equal(
retvec(c(1,2,3,4)),
c(1,2,3,4)
)
})

test_that("Functions can be exposed in fit object", {
fit$expose_functions(verbose = TRUE)

Expand Down