Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix expose_functions under 2.32 #769

Merged
merged 5 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ jobs:
config:
- {os: macOS-latest, r: 'devel', rtools: ''}
- {os: macOS-latest, r: 'release', rtools: ''}
- {os: windows-latest, r: 'devel', rtools: '43'}
- {os: windows-latest, r: 'release', rtools: '42'}
- {os: windows-latest, r: 'oldrel', rtools: '40'}
- {os: windows-latest, r: 'devel', rtools: ''}
- {os: windows-latest, r: 'release', rtools: ''}
- {os: windows-latest, r: 'oldrel', rtools: '42'}
- {os: ubuntu-20.04, r: 'devel', rtools: ''}
- {os: ubuntu-20.04, r: 'release', rtools: ''}
- {os: ubuntu-20.04, r: 'oldrel', rtools: ''}
Expand Down Expand Up @@ -56,7 +56,7 @@ jobs:
sudo apt-get install -y libcurl4-openssl-dev || true
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true

- uses: r-lib/actions/setup-r@v2.4.0
- uses: r-lib/actions/setup-r@v2.6.3
with:
r-version: ${{ matrix.config.r }}
rtools-version: ${{ matrix.config.rtools }}
Expand Down
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,5 @@ Suggests:
rmarkdown,
testthat (>= 2.1.0),
Rcpp,
RcppEigen,
decor
RcppEigen
VignetteBuilder: knitr
2 changes: 1 addition & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ compile <- function(quiet = TRUE,
private$precompile_stanc_options_ <- NULL
private$precompile_include_paths_ <- NULL
private$model_methods_env_ <- new.env()
private$model_methods_env_$hpp_code_ <- readLines(private$hpp_file_, warn = FALSE)
suppressWarnings(private$model_methods_env_$hpp_code_ <- readLines(private$hpp_file_, warn = FALSE))
if (compile_model_methods) {
expose_model_methods(env = private$model_methods_env_,
verbose = !quiet,
Expand Down
66 changes: 46 additions & 20 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -797,41 +797,67 @@ get_standalone_hpp <- function(stan_file, stancflags) {
name <- strip_ext(basename(stan_file))
path <- dirname(stan_file)
hpp_path <- file.path(path, paste0(name, ".hpp"))
hpp <- readLines(hpp_path, warn = FALSE)
hpp <- suppressWarnings(readLines(hpp_path, warn = FALSE))
unlink(hpp_path)
hpp
} else {
invisible(NULL)
}
}

get_function_name <- function(fun_start, fun_end, model_lines) {
fun_string <- paste(model_lines[(fun_start+1):fun_end], collapse = " ")
fun_name <- gsub("auto ", "", fun_string, fixed = TRUE)
sub("\\(.*", "", fun_name, perl = TRUE)
}

# Construct the plain return type for a standalone function by
# looking up the return type of the functor declaration and replacing
# the template types (i.e., T0__) with double
get_plain_rtn <- function(fun_body, model_lines) {
fun_string <- paste(fun_body[-1], collapse = "\n")
fun_props <- decor::parse_cpp_function(fun_string)

struct_start <- grep(paste0("struct ", fun_props$name, "_functor"), model_lines)
struct_op_start <- grep("operator()", model_lines[-(1:struct_start)])[1] + struct_start
rtn_type <- paste0(model_lines[struct_start:struct_op_start], collapse = " ")

rm_operator <- gsub("operator().*", "", rtn_type)
rm_prev <- gsub(".*\\{", "", rm_operator)
get_plain_rtn <- function(fun_start, fun_end, model_lines) {
fun_name <- get_function_name(fun_start, fun_end, model_lines)

# Depending on the version of stanc3, the standalone functions
# with a plain return type can either be wrapped in a struct as a functor,
# or as a separate forward declaration
struct_name <- paste0("struct ", fun_name, "_functor")

if (any(grepl(struct_name, model_lines))) {
struct_start <- grep(struct_name, model_lines)
struct_op_start <- grep("operator()", model_lines[-(1:struct_start)])[1] + struct_start
rtn_type <- paste0(model_lines[struct_start:struct_op_start], collapse = " ")
rm_operator <- gsub("operator().*", "", rtn_type)
rm_prev <- gsub(".*\\{", "", rm_operator)
} else {
# Find first declaration of function (will be the forward declaration)
first_decl <- grep(paste0(fun_name,"\\("), model_lines)[1]

# The return type will be between the function name and the semicolon terminating
# the previous line
last_scolon <- grep(";", model_lines[1:first_decl])
last_scolon <- ifelse(last_scolon[length(last_scolon)] == first_decl,
last_scolon[length(last_scolon) - 1],
last_scolon[length(last_scolon)])
rtn_type_full <- paste0(model_lines[last_scolon:first_decl], collapse = " ")
rm_fun_name <- gsub(paste0(fun_name, ".*"), "", rtn_type_full)
rm_prev <- gsub(".*;", "", rm_fun_name)
}
rm_template <- gsub("template <typename(.*?)> ", "", rm_prev)
gsub("T([0-9])*__", "double", rm_template)
}


# Prepare the c++ code for a standalone function so that it can be exported to R:
# - Replace the auto return type with the plain type
# - Add Rcpp::export attribute
# - Remove the pstream__ argument and pass Rcpp::Rcout by default
# - Replace the boost::ecuyer1988& base_rng__ argument with an integer seed argument
# that instantiates an RNG
prep_fun_cpp <- function(fun_body, model_lines) {
fun_body <- gsub("auto", get_plain_rtn(fun_body, model_lines), fun_body)
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream* pstream__ = nullptr", "", fun_body, fixed = TRUE)
prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
fun_body <- paste(model_lines[fun_start:fun_end], collapse = " ")
fun_body <- gsub("auto", get_plain_rtn(fun_start, fun_end, model_lines), fun_body)
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
fun_body <- gsub("boost::ecuyer1988& base_rng__", "size_t seed = 0", fun_body, fixed = TRUE)
fun_body <- gsub("base_rng__,", "*(new boost::ecuyer1988(seed)),", fun_body, fixed = TRUE)
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
Expand All @@ -844,12 +870,13 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
funs <- c(funs, length(env$hpp_code))

stan_funs <- sapply(seq_len(length(funs) - 1), function(ind) {
fun_body <- env$hpp_code[funs[ind]:(funs[ind + 1] - 1)]
prep_fun_cpp(fun_body, env$hpp_code)
fun_end <- funs[ind + 1]
fun_end <- ifelse(env$hpp_code[fun_end] == "}", fun_end, fun_end - 1)
prep_fun_cpp(funs[ind], fun_end, env$hpp_code)
})

env$fun_names <- sapply(stan_funs, function(fun) {
decor::parse_cpp_function(fun, is_attribute = TRUE)$name
env$fun_names <- sapply(seq_len(length(funs) - 1), function(ind) {
get_function_name(funs[ind], funs[ind + 1], env$hpp_code)
})

dups <- env$fun_names[duplicated(env$fun_names)]
Expand Down Expand Up @@ -888,7 +915,6 @@ expose_functions <- function(function_env, global = FALSE, verbose = FALSE) {
}
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
require_suggested_package("decor")
if (function_env$compiled) {
if (!global) {
message("Functions already compiled, nothing to do!")
Expand Down
65 changes: 25 additions & 40 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,38 @@ context("model-expose-functions")

set_cmdstan_path()

if (cmdstan_version() < "2.32.0") {
function_decl <- "
functions {
int rtn_int(int x) { return x; }
real rtn_real(real x) { return x; }
vector rtn_vec(vector x) { return x; }
row_vector rtn_rowvec(row_vector x) { return x; }
matrix rtn_matrix(matrix x) { return x; }

array[] int rtn_int_array(array[] int x) { return x; }
array[] real rtn_real_array(array[] real x) { return x; }
array[] vector rtn_vec_array(array[] vector x) { return x; }
array[] row_vector rtn_rowvec_array(array[] row_vector x) { return x; }
array[] matrix rtn_matrix_array(array[] matrix x) { return x; }
}"
stan_prog <- paste(function_decl,
paste(readLines(testing_stan_file("bernoulli")),
collapse = "\n"),
collapse = "\n")
model <- write_stan_file(stan_prog)
data_list <- testing_data("bernoulli")
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)
}
function_decl <- "
functions {
int rtn_int(int x) { return x; }
real rtn_real(real x) { return x; }
vector rtn_vec(vector x) { return x; }
row_vector rtn_rowvec(row_vector x) { return x; }
matrix rtn_matrix(matrix x) { return x; }

array[] int rtn_int_array(array[] int x) { return x; }
array[] real rtn_real_array(array[] real x) { return x; }
array[] vector rtn_vec_array(array[] vector x) { return x; }
array[] row_vector rtn_rowvec_array(array[] row_vector x) { return x; }
array[] matrix rtn_matrix_array(array[] matrix x) { return x; }
}"
stan_prog <- paste(function_decl,
paste(readLines(testing_stan_file("bernoulli")),
collapse = "\n"),
collapse = "\n")
model <- write_stan_file(stan_prog)
data_list <- testing_data("bernoulli")
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)


test_that("Functions can be exposed in model object", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")
expect_no_error(mod$expose_functions(verbose = TRUE))
})


test_that("Functions handle types correctly", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")

expect_equal(mod$functions$rtn_int(10), 10)
expect_equal(mod$functions$rtn_real(1.67), 1.67)
Expand All @@ -63,7 +59,6 @@ test_that("Functions handle types correctly", {

test_that("Functions can be exposed in fit object", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")
fit$expose_functions(verbose = TRUE)

expect_equal(
Expand All @@ -74,7 +69,6 @@ test_that("Functions can be exposed in fit object", {

test_that("Compiled functions can be copied to global environment", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")
expect_message(
fit$expose_functions(global = TRUE),
"Functions already compiled, copying to global environment",
Expand All @@ -90,7 +84,6 @@ test_that("Compiled functions can be copied to global environment", {

test_that("Functions can be compiled with model", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")
mod <- cmdstan_model(model, force_recompile = TRUE, compile_standalone = TRUE)
fit <- mod$sample(data = data_list)

Expand Down Expand Up @@ -119,7 +112,6 @@ test_that("Functions can be compiled with model", {

test_that("rng functions can be exposed", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")
function_decl <- "functions { real normal_rng(real mu) { return normal_rng(mu, 1); } }"
stan_prog <- paste(function_decl,
paste(readLines(testing_stan_file("bernoulli")),
Expand All @@ -140,7 +132,6 @@ test_that("rng functions can be exposed", {

test_that("Overloaded functions give meaningful errors", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")

funcode <- "
functions {
Expand All @@ -159,16 +150,8 @@ test_that("Overloaded functions give meaningful errors", {

test_that("Exposing external functions errors before v2.32", {
skip_if(os_is_wsl())
skip_if(cmdstan_version() >= "2.32.0")

if (getRversion() < '3.5.0') {
dir <- tempdir()
} else {
dir <- tempdir(check = TRUE)
}
install_cmdstan(dir = dir, cores = 2, quiet = FALSE,
overwrite = TRUE, version = "2.31.0",
wsl = os_is_wsl())
fake_cmdstan_version("2.26.0")

tmpfile <- tempfile(fileext = ".hpp")
hpp <-
Expand All @@ -189,4 +172,6 @@ test_that("Exposing external functions errors before v2.32", {
},
"Exporting standalone functions with external C++ is not available before CmdStan 2.32",
fixed = TRUE)

reset_cmdstan_version()
})