Skip to content

Commit

Permalink
Allow exposing functions without sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Oct 16, 2022
1 parent cdd62a0 commit 37cc98f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 37 deletions.
28 changes: 4 additions & 24 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ CmdStanFit <- R6::R6Class(
"rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)\n")
}
invisible(self)
},
expose_functions = function(global = FALSE, verbose = FALSE) {
expose_functions(self$functions, global, verbose)
invisible(NULL)
}
),
private = list(
Expand Down Expand Up @@ -280,30 +284,6 @@ init <- function() {
}
CmdStanFit$set("public", name = "init", value = init)

expose_functions <- function(global = FALSE, verbose = FALSE) {
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
require_suggested_package("decor")
if (self$functions$compiled) {
if (!global) {
message("Functions already compiled, nothing to do!")
} else {
message("Functions already compiled, copying to global environment")
# Create reference to global environment, avoids NOTE about assigning to global
pos <- 1
envir = as.environment(pos)
lapply(self$functions$fun_names, function(fun_name) {
assign(fun_name, get(fun_name, self$functions), envir)
})
}
} else {
message("Compiling standalone functions...")
expose_functions(self$functions, verbose, global)
}
invisible(NULL)
}
CmdStanFit$set("public", name = "expose_functions", value = expose_functions)

#' Compile additional methods for accessing the model log-probability function
#' and parameter constraining and unconstraining. This requires the `Rcpp` package.
#'
Expand Down
28 changes: 16 additions & 12 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,15 @@ CmdStanModel <- R6::R6Class(
precompile_cpp_options_ = NULL,
precompile_stanc_options_ = NULL,
precompile_include_paths_ = NULL,
variables_ = NULL,
standalone_env_ = NULL
variables_ = NULL
),
public = list(
functions = NULL,
initialize = function(stan_file = NULL, exe_file = NULL, compile, ...) {
args <- list(...)
private$dir_ <- args$dir
self$functions <- new.env()
self$functions$compiled <- FALSE
if (!is.null(stan_file)) {
checkmate::assert_file_exists(stan_file, access = "r", extension = "stan")
checkmate::assert_flag(compile)
Expand Down Expand Up @@ -323,6 +325,10 @@ CmdStanModel <- R6::R6Class(
"- ", new_hpp_loc)
private$hpp_file_ <- new_hpp_loc
invisible(private$hpp_file_)
},
expose_functions = function(global = FALSE, verbose = FALSE) {
expose_functions(self$functions, global, verbose)
invisible(NULL)
}
)
)
Expand Down Expand Up @@ -561,11 +567,9 @@ compile <- function(quiet = TRUE,
}
}
stancflags_standalone <- c("--standalone-functions", stancflags_val, stanc_built_options)
private$standalone_env_ <- new.env()
private$standalone_env_$compiled <- FALSE
private$standalone_env_$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone)
self$functions$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone)
if (compile_standalone) {
expose_functions(private$standalone_env_, !quiet)
expose_functions(self$functions, !quiet)
}
stancflags_val <- paste0("STANCFLAGS += ", stancflags_val, paste0(" ", stanc_built_options, collapse = " "))
withr::with_path(
Expand Down Expand Up @@ -1127,7 +1131,7 @@ sample <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1286,7 +1290,7 @@ sample_mpi <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1399,7 +1403,7 @@ optimize <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1518,7 +1522,7 @@ variational <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1636,7 +1640,7 @@ generate_quantities <- function(fitted_params,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = seq_along(fitted_params_files),
Expand Down Expand Up @@ -1701,7 +1705,7 @@ diagnose <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
standalone_env = self$functions,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down
25 changes: 24 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ prep_fun_cpp <- function(fun_body, model_lines) {
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
}

expose_functions <- function(env, verbose = FALSE, global = FALSE) {
compile_functions <- function(env, verbose = FALSE, global = FALSE) {
funs <- grep("// [[stan::function]]", env$hpp_code, fixed = TRUE)
funs <- c(funs, length(env$hpp_code))

Expand All @@ -702,3 +702,26 @@ expose_functions <- function(env, verbose = FALSE, global = FALSE) {
env$compiled <- TRUE
invisible(NULL)
}

expose_functions <- function(function_env, global = FALSE, verbose = FALSE) {
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
require_suggested_package("decor")
if (function_env$compiled) {
if (!global) {
message("Functions already compiled, nothing to do!")
} else {
message("Functions already compiled, copying to global environment")
# Create reference to global environment, avoids NOTE about assigning to global
pos <- 1
envir = as.environment(pos)
lapply(function_env$fun_names, function(fun_name) {
assign(fun_name, get(fun_name, function_env), envir)
})
}
} else {
message("Compiling standalone functions...")
compile_functions(function_env, verbose, global)
}
invisible(NULL)
}
17 changes: 17 additions & 0 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@ data_list <- testing_data("bernoulli")
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)


test_that("Functions can be exposed in model object", {
mod$expose_functions(verbose = TRUE)

expect_equal(
fit$functions$retvec(c(1,2,3,4)),
c(1,2,3,4)
)

mod$expose_functions(global = TRUE, verbose = TRUE)

expect_equal(
retvec(c(1,2,3,4)),
c(1,2,3,4)
)
})

test_that("Functions can be exposed in fit object", {
fit$expose_functions(verbose = TRUE)

Expand Down

0 comments on commit 37cc98f

Please sign in to comment.