Skip to content

Commit

Permalink
Merge pull request #702 from andrjohns/expose-stan-functions
Browse files Browse the repository at this point in the history
Add optional method for exposing stan functions to R
  • Loading branch information
rok-cesnovar authored Oct 12, 2022
2 parents f01934c + b8707ef commit cdd62a0
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 15 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,6 @@ Suggests:
rmarkdown,
testthat (>= 2.1.0),
Rcpp,
RcppEigen
RcppEigen,
decor
VignetteBuilder: knitr
2 changes: 2 additions & 0 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ CmdStanArgs <- R6::R6Class(
stan_file = NULL,
stan_code = NULL,
model_methods_env = NULL,
standalone_env = NULL,
exe_file,
proc_ids,
method_args,
Expand All @@ -45,6 +46,7 @@ CmdStanArgs <- R6::R6Class(
self$stan_code <- stan_code
self$exe_file <- exe_file
self$model_methods_env <- model_methods_env
self$standalone_env <- standalone_env
self$proc_ids <- proc_ids
self$data_file <- data_file
self$seed <- seed
Expand Down
26 changes: 26 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ CmdStanFit <- R6::R6Class(
classname = "CmdStanFit",
public = list(
runset = NULL,
functions = NULL,
initialize = function(runset) {
checkmate::assert_r6(runset, classes = "CmdStanRun")
self$runset <- runset
private$model_methods_env_ <- runset$model_methods_env()
self$functions <- runset$standalone_env()

if (!is.null(private$model_methods_env_$model_ptr)) {
initialize_model_pointer(private$model_methods_env_, self$data_file(), 0)
Expand Down Expand Up @@ -278,6 +280,30 @@ init <- function() {
}
CmdStanFit$set("public", name = "init", value = init)

expose_functions <- function(global = FALSE, verbose = FALSE) {
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
require_suggested_package("decor")
if (self$functions$compiled) {
if (!global) {
message("Functions already compiled, nothing to do!")
} else {
message("Functions already compiled, copying to global environment")
# Create reference to global environment, avoids NOTE about assigning to global
pos <- 1
envir = as.environment(pos)
lapply(self$functions$fun_names, function(fun_name) {
assign(fun_name, get(fun_name, self$functions), envir)
})
}
} else {
message("Compiling standalone functions...")
expose_functions(self$functions, verbose, global)
}
invisible(NULL)
}
CmdStanFit$set("public", name = "expose_functions", value = expose_functions)

#' Compile additional methods for accessing the model log-probability function
#' and parameter constraining and unconstraining. This requires the `Rcpp` package.
#'
Expand Down
18 changes: 17 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ CmdStanModel <- R6::R6Class(
precompile_cpp_options_ = NULL,
precompile_stanc_options_ = NULL,
precompile_include_paths_ = NULL,
variables_ = NULL
variables_ = NULL,
standalone_env_ = NULL
),
public = list(
initialize = function(stan_file = NULL, exe_file = NULL, compile, ...) {
Expand Down Expand Up @@ -387,6 +388,7 @@ CmdStanModel <- R6::R6Class(
#' (`log_prob()`, `grad_log_prob()`, `constrain_pars()`, `unconstrain_pars()`)
#' @param compile_hessian_method (logical) Should the (experimental) `hessian()` method be
#' be compiled with the model methods?
#' @param compile_standalone (logical) Should functions in the Stan model be compiled for used in R?
#' @param threads Deprecated and will be removed in a future release. Please
#' turn on threading via `cpp_options = list(stan_threads = TRUE)` instead.
#'
Expand Down Expand Up @@ -438,6 +440,7 @@ compile <- function(quiet = TRUE,
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
compile_model_methods = FALSE,
compile_hessian_method = FALSE,
compile_standalone = FALSE,
#deprecated
threads = FALSE) {
if (length(self$stan_file()) == 0) {
Expand Down Expand Up @@ -557,6 +560,13 @@ compile <- function(quiet = TRUE,
stanc_built_options <- c(stanc_built_options, paste0("--", option_name, "=", "'", stanc_options[[i]], "'"))
}
}
stancflags_standalone <- c("--standalone-functions", stancflags_val, stanc_built_options)
private$standalone_env_ <- new.env()
private$standalone_env_$compiled <- FALSE
private$standalone_env_$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone)
if (compile_standalone) {
expose_functions(private$standalone_env_, !quiet)
}
stancflags_val <- paste0("STANCFLAGS += ", stancflags_val, paste0(" ", stanc_built_options, collapse = " "))
withr::with_path(
c(
Expand Down Expand Up @@ -1117,6 +1127,7 @@ sample <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1275,6 +1286,7 @@ sample_mpi <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
Expand Down Expand Up @@ -1387,6 +1399,7 @@ optimize <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1505,6 +1518,7 @@ variational <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down Expand Up @@ -1622,6 +1636,7 @@ generate_quantities <- function(fitted_params,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = seq_along(fitted_params_files),
Expand Down Expand Up @@ -1686,6 +1701,7 @@ diagnose <- function(data = NULL,
stan_file = self$stan_file(),
stan_code = suppressWarnings(self$code()),
model_methods_env = private$model_methods_env_,
standalone_env = private$standalone_env_,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
Expand Down
1 change: 1 addition & 0 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ CmdStanRun <- R6::R6Class(
exe_file = function() self$args$exe_file,
stan_code = function() self$args$stan_code,
model_methods_env = function() self$args$model_methods_env,
standalone_env = function() self$args$standalone_env,
model_name = function() self$args$model_name,
method = function() self$args$method,
data_file = function() self$args$data_file,
Expand Down
119 changes: 106 additions & 13 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ as_mcmc.list <- function(x) {
return(mcmc_list)
}

# Model methods & expose_functions helpers ------------------------------------------------------
get_cmdstan_flags <- function(flag_name) {
cmdstan_path <- cmdstanr::cmdstan_path()
flags <- processx::run(
Expand Down Expand Up @@ -557,19 +558,7 @@ get_cmdstan_flags <- function(flag_name) {
paste(flags, collapse = " ")
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))

if (hessian) {
code <- c(code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
}

code <- paste(code, collapse = "\n")

rcpp_source_stan <- function(code, env, verbose = FALSE) {
cxxflags <- get_cmdstan_flags("CXXFLAGS")
libs <- c("LDLIBS", "LIBSUNDIALS", "TBB_TARGETS", "LDFLAGS_TBB")
libs <- paste(sapply(libs, get_cmdstan_flags), collapse = "")
Expand All @@ -592,6 +581,22 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
invisible(NULL)
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))

if (hessian) {
code <- c(code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
}

code <- paste(code, collapse = "\n")
rcpp_source_stan(code, env, verbose)
invisible(NULL)
}

initialize_model_pointer <- function(env, data, seed = 0) {
ptr_and_rng <- env$model_ptr(data, seed)
env$model_ptr_ <- ptr_and_rng$model_ptr
Expand All @@ -609,3 +614,91 @@ create_skeleton <- function(model_variables) {
})
stats::setNames(skeleton, names(model_pars))
}

get_standalone_hpp <- function(stan_file, stancflags) {
status <- withr::with_path(
c(
toolchain_PATH_env_var(),
tbb_path()
),
wsl_compatible_run(
command = stanc_cmd(),
args = c(stan_file,
stancflags),
wd = cmdstan_path(),
error_on_status = FALSE
)
)
if (status$status == 0) {
name <- strip_ext(basename(stan_file))
path <- dirname(stan_file)
hpp_path <- file.path(path, paste0(name, ".hpp"))
hpp <- readLines(hpp_path)
unlink(hpp_path)
hpp
} else {
invisible(NULL)
}
}

# Construct the plain return type for a standalone function by
# looking up the return type of the functor declaration and replacing
# the template types (i.e., T0__) with double
get_plain_rtn <- function(fun_body, model_lines) {
fun_props <- decor::parse_cpp_function(paste(fun_body[-1], collapse = "\n"))
struct_start <- grep(paste0("struct ", fun_props$name, "_functor"), model_lines)
struct_op_start <- grep("operator()", model_lines[-(1:struct_start)])[1] + struct_start

struct_rtn <- grep("nullptr>", model_lines[struct_start:struct_op_start], fixed = TRUE) + struct_start

rtn_type <- paste0(model_lines[struct_rtn:struct_op_start], collapse = " ")
rm_trailing_nullptr <- gsub(".*nullptr>[^,]", "", rtn_type)
rm_operator <- gsub("operator().*", "", rtn_type)
repl_dbl <- gsub("T[0-9*]__", "double", rm_operator)
gsub("(^\\s|\\s$)", "", repl_dbl)
}

# Prepare the c++ code for a standalone function so that it can be exported to R:
# - Replace the auto return type with the plain type
# - Add Rcpp::export attribute
# - Remove the pstream__ argument and pass Rcpp::Rcout by default
# - Replace the boost::ecuyer1988& base_rng__ argument with an integer seed argument
# that instantiates an RNG
prep_fun_cpp <- function(fun_body, model_lines) {
fun_body <- gsub("auto", get_plain_rtn(fun_body, model_lines), fun_body)
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream* pstream__ = nullptr", "", fun_body, fixed = TRUE)
fun_body <- gsub("boost::ecuyer1988& base_rng__", "size_t seed = 0", fun_body, fixed = TRUE)
fun_body <- gsub("base_rng__,", "*(new boost::ecuyer1988(seed)),", fun_body, fixed = TRUE)
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
fun_body <- paste(fun_body, collapse = "\n")
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
}

expose_functions <- function(env, verbose = FALSE, global = FALSE) {
funs <- grep("// [[stan::function]]", env$hpp_code, fixed = TRUE)
funs <- c(funs, length(env$hpp_code))

stan_funs <- sapply(seq_len(length(funs) - 1), function(ind) {
fun_body <- env$hpp_code[funs[ind]:(funs[ind + 1] - 1)]
prep_fun_cpp(fun_body, env$hpp_code)
})

env$fun_names <- sapply(stan_funs, function(fun) {
decor::parse_cpp_function(fun, is_attribute = TRUE)$name
})

mod_stan_funs <- paste(c(
env$hpp_code[1:(funs[1] - 1)],
"#include <RcppEigen.h>",
"// [[Rcpp::depends(RcppEigen)]]",
stan_funs),
collapse = "\n")
if (global) {
rcpp_source_stan(mod_stan_funs, globalenv(), verbose)
} else {
rcpp_source_stan(mod_stan_funs, env, verbose)
}
env$compiled <- TRUE
invisible(NULL)
}
3 changes: 3 additions & 0 deletions man/model-method-compile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit cdd62a0

Please sign in to comment.