Skip to content

Commit

Permalink
Merge pull request #704 from andrjohns/constrain-return-all
Browse files Browse the repository at this point in the history
Make constrain_pars behaviour consistent with rstan
  • Loading branch information
rok-cesnovar authored Oct 16, 2022
2 parents cdd62a0 + 4ecfd85 commit 9aa36ee
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 28 deletions.
29 changes: 18 additions & 11 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
if (hessian) {
message("The hessian method relies on higher-order autodiff ",
"which is still experimental. Please report any compilation ",
"errors that you encounter",
call. = FALSE)
"errors that you encounter")
}
message("Compiling additional model methods...")
if (is.null(private$model_methods_env_$model_ptr)) {
Expand Down Expand Up @@ -459,18 +458,15 @@ unconstrain_pars <- function(pars) {
model_par_names <- names(self$runset$args$model_variables$parameters)
prov_par_names <- names(pars)

prov_pars_not_in_model <- which(!(prov_par_names %in% model_par_names))
if (length(prov_pars_not_in_model) > 0) {
stop("Provided parameter(s): ", paste(prov_par_names[prov_pars_not_in_model], collapse = ","),
" not present in model!", call. = FALSE)
}

model_pars_not_prov <- which(!(model_par_names %in% prov_par_names))
if (length(model_pars_not_prov) > 0) {
stop("Model parameter(s): ", paste(model_par_names[model_pars_not_prov], collapse = ","),
" not provided!", call. = FALSE)
}

# Ignore extraneous parameters
model_pars_only <- pars[model_par_names]

stan_pars <- process_init_list(list(pars), num_procs = 1, self$runset$args$model_variables)
private$model_methods_env_$unconstrain_pars(private$model_methods_env_$model_ptr_, stan_pars)
}
Expand All @@ -484,14 +480,18 @@ CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars)
#' the constrained scale
#'
#' @param upars (numeric) A vector of unconstrained parameters to constrain
#' @param transformed_parameters (boolean) Whether to return transformed parameters
#' implied by newly-constrained parameters (defaults to TRUE)
#' @param generated_quantities (boolean) Whether to return generated quantities
#' implied by newly-constrained parameters (defaults to TRUE)
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$constrain_pars(upars = c(0.5, 1.2, 1.1, 2.2, 1.1))
#' }
#'
constrain_pars <- function(upars) {
constrain_pars <- function(upars, transformed_parameters = TRUE, generated_quantities = TRUE) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
Expand All @@ -500,8 +500,15 @@ constrain_pars <- function(upars) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(upars), " were provided!", call. = FALSE)
}
cpars <- private$model_methods_env_$constrain_pars(private$model_methods_env_$model_ptr_, private$model_methods_env_$model_rng_, upars)
skeleton <- create_skeleton(self$runset$args$model_variables)
cpars <- private$model_methods_env_$constrain_pars(
private$model_methods_env_$model_ptr_,
private$model_methods_env_$model_rng_,
upars, transformed_parameters, generated_quantities)

skeleton <- create_skeleton(private$model_methods_env_$param_metadata_,
self$runset$args$model_variables,
transformed_parameters,
generated_quantities)
utils::relist(cpars, skeleton)
}
CmdStanFit$set("public", name = "constrain_pars", value = constrain_pars)
Expand Down
21 changes: 14 additions & 7 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -602,17 +602,24 @@ initialize_model_pointer <- function(env, data, seed = 0) {
env$model_ptr_ <- ptr_and_rng$model_ptr
env$model_rng_ <- ptr_and_rng$base_rng
env$num_upars_ <- env$get_num_upars(env$model_ptr_)
env$param_metadata_ <- env$get_param_metadata(env$model_ptr_)
invisible(NULL)
}

create_skeleton <- function(model_variables) {
model_pars <- model_variables$parameters
skeleton <- lapply(model_pars, function(par) {
dims <- par$dimensions
dims <- ifelse(dims == 0, 1, dims)
array(0, dim = dims)
create_skeleton <- function(param_metadata, model_variables,
transformed_parameters, generated_quantities) {
target_params <- names(model_variables$parameters)
if (transformed_parameters) {
target_params <- c(target_params,
names(model_variables$transformed_parameters))
}
if (generated_quantities) {
target_params <- c(target_params,
names(model_variables$generated_quantities))
}
lapply(param_metadata[target_params], function(par_dims) {
array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims))
})
stats::setNames(skeleton, names(model_pars))
}

get_standalone_hpp <- function(stan_file, stancflags) {
Expand Down
25 changes: 23 additions & 2 deletions inst/include/model_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,24 @@ size_t get_num_upars(SEXP ext_model_ptr) {
return ptr->num_params_r();
}

// [[Rcpp::export]]
Rcpp::List get_param_metadata(SEXP ext_model_ptr) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<std::string> param_names;
std::vector<std::vector<size_t> > param_dims;
ptr->get_param_names(param_names);
ptr->get_dims(param_dims);

Rcpp::List param_metadata = Rcpp::List::create(
Rcpp::Named(param_names[0]) = param_dims[0]
);
for (size_t i = 1; i < param_names.size(); i++) {
param_metadata.push_back(param_dims[i], param_names[i]);
}

return param_metadata;
}

// [[Rcpp::export]]
std::vector<double> unconstrain_pars(SEXP ext_model_ptr, std::string init_path) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Expand All @@ -79,12 +97,15 @@ std::vector<double> unconstrain_pars(SEXP ext_model_ptr, std::string init_path)
}

// [[Rcpp::export]]
std::vector<double> constrain_pars(SEXP ext_model_ptr, SEXP base_rng, std::vector<double> upars) {
std::vector<double> constrain_pars(SEXP ext_model_ptr, SEXP base_rng,
std::vector<double> upars,
bool return_trans_pars,
bool return_gen_quants) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Rcpp::XPtr<boost::ecuyer1988> rng(base_rng);
std::vector<int> params_i;
std::vector<double> vars;

ptr->write_array(*rng.get(), upars, params_i, vars, false, false);
ptr->write_array(*rng.get(), upars, params_i, vars, return_trans_pars, return_gen_quants);
return vars;
}
12 changes: 11 additions & 1 deletion man/fit-method-constrain_pars.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 10 additions & 7 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
context("model-methods")

set_cmdstan_path()
mod <- cmdstan_model(testing_stan_file("bernoulli"), force_recompile = TRUE)
mod <- cmdstan_model(testing_stan_file("bernoulli_log_lik"), force_recompile = TRUE)
data_list <- testing_data("bernoulli")
fit <- mod$sample(data = data_list, chains = 1)

Expand Down Expand Up @@ -57,12 +57,20 @@ test_that("Methods return correct values", {
expect_equal(fit$hessian(upars=c(0.1)), hessian)

cpars <- fit$constrain_pars(c(0.1))
expect_equal(cpars, list(theta = 0.52497918747894001257))
cpars_true <- list(
theta = 0.52497918747894001257,
log_lik = rep(-7.2439666007357095268, data_list$N)
)
expect_equal(cpars, cpars_true)

expect_equal(fit$constrain_pars(c(0.1), generated_quantities = FALSE),
list(theta = 0.52497918747894001257))

upars <- fit$unconstrain_pars(cpars)
expect_equal(upars, c(0.1))
})


test_that("methods error for incorrect inputs", {
expect_error(
fit$log_prob(c(1,2)),
Expand All @@ -79,11 +87,6 @@ test_that("methods error for incorrect inputs", {
"Model has 1 unconstrained parameter(s), but 2 were provided!",
fixed = TRUE
)
expect_error(
fit$unconstrain_pars(list(theta = 0.5, dummy = 5)),
"Provided parameter(s): dummy not present in model!",
fixed = TRUE
)
expect_error(
fit$constrain_pars(c(1,2)),
"Model has 1 unconstrained parameter(s), but 2 were provided!",
Expand Down

0 comments on commit 9aa36ee

Please sign in to comment.