From 7c121e6d0e77c53dc01b0a8e9aa9576f7d06ea1c Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sun, 16 Oct 2022 14:52:35 +0300 Subject: [PATCH 1/2] Make constrain_pars behaviour consistent with rstan --- R/fit.R | 26 +++++++++++++++++--------- R/utils.R | 21 ++++++++++++++------- inst/include/model_methods.cpp | 25 +++++++++++++++++++++++-- man/fit-method-constrain_pars.Rd | 12 +++++++++++- tests/testthat/test-model-methods.R | 17 ++++++++++------- 5 files changed, 75 insertions(+), 26 deletions(-) diff --git a/R/fit.R b/R/fit.R index 7f80f6d1b..eb1b4ab71 100644 --- a/R/fit.R +++ b/R/fit.R @@ -459,18 +459,15 @@ unconstrain_pars <- function(pars) { model_par_names <- names(self$runset$args$model_variables$parameters) prov_par_names <- names(pars) - prov_pars_not_in_model <- which(!(prov_par_names %in% model_par_names)) - if (length(prov_pars_not_in_model) > 0) { - stop("Provided parameter(s): ", paste(prov_par_names[prov_pars_not_in_model], collapse = ","), - " not present in model!", call. = FALSE) - } - model_pars_not_prov <- which(!(model_par_names %in% prov_par_names)) if (length(model_pars_not_prov) > 0) { stop("Model parameter(s): ", paste(model_par_names[model_pars_not_prov], collapse = ","), " not provided!", call. = FALSE) } + # Ignore extraneous parameters + model_pars_only <- pars[model_par_names] + stan_pars <- process_init_list(list(pars), num_procs = 1, self$runset$args$model_variables) private$model_methods_env_$unconstrain_pars(private$model_methods_env_$model_ptr_, stan_pars) } @@ -484,6 +481,10 @@ CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars) #' the constrained scale #' #' @param upars (numeric) A vector of unconstrained parameters to constrain +#' @param transformed_parameters (boolean) Whether to return transformed parameters +#' implied by newly-constrained parameters (defaults to TRUE) +#' @param generated_quantities (boolean) Whether to return generated quantities +#' implied by newly-constrained parameters (defaults to TRUE) #' #' @examples #' \dontrun{ @@ -491,7 +492,7 @@ CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars) #' fit_mcmc$constrain_pars(upars = c(0.5, 1.2, 1.1, 2.2, 1.1)) #' } #' -constrain_pars <- function(upars) { +constrain_pars <- function(upars, transformed_parameters = TRUE, generated_quantities = TRUE) { if (is.null(private$model_methods_env_$model_ptr)) { stop("The method has not been compiled, please call `init_model_methods()` first", call. = FALSE) @@ -500,8 +501,15 @@ constrain_pars <- function(upars) { stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ", length(upars), " were provided!", call. = FALSE) } - cpars <- private$model_methods_env_$constrain_pars(private$model_methods_env_$model_ptr_, private$model_methods_env_$model_rng_, upars) - skeleton <- create_skeleton(self$runset$args$model_variables) + cpars <- private$model_methods_env_$constrain_pars( + private$model_methods_env_$model_ptr_, + private$model_methods_env_$model_rng_, + upars, transformed_parameters, generated_quantities) + + skeleton <- create_skeleton(private$model_methods_env_$param_metadata_, + self$runset$args$model_variables, + transformed_parameters, + generated_quantities) utils::relist(cpars, skeleton) } CmdStanFit$set("public", name = "constrain_pars", value = constrain_pars) diff --git a/R/utils.R b/R/utils.R index c02e9fe17..17fb0b9f5 100644 --- a/R/utils.R +++ b/R/utils.R @@ -602,17 +602,24 @@ initialize_model_pointer <- function(env, data, seed = 0) { env$model_ptr_ <- ptr_and_rng$model_ptr env$model_rng_ <- ptr_and_rng$base_rng env$num_upars_ <- env$get_num_upars(env$model_ptr_) + env$param_metadata_ <- env$get_param_metadata(env$model_ptr_) invisible(NULL) } -create_skeleton <- function(model_variables) { - model_pars <- model_variables$parameters - skeleton <- lapply(model_pars, function(par) { - dims <- par$dimensions - dims <- ifelse(dims == 0, 1, dims) - array(0, dim = dims) +create_skeleton <- function(param_metadata, model_variables, + transformed_parameters, generated_quantities) { + target_params <- names(model_variables$parameters) + if (transformed_parameters) { + target_params <- c(target_params, + names(model_variables$transformed_parameters)) + } + if (generated_quantities) { + target_params <- c(target_params, + names(model_variables$generated_quantities)) + } + lapply(param_metadata[target_params], function(par_dims) { + array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims)) }) - stats::setNames(skeleton, names(model_pars)) } get_standalone_hpp <- function(stan_file, stancflags) { diff --git a/inst/include/model_methods.cpp b/inst/include/model_methods.cpp index d1d622187..bff21fb64 100644 --- a/inst/include/model_methods.cpp +++ b/inst/include/model_methods.cpp @@ -69,6 +69,24 @@ size_t get_num_upars(SEXP ext_model_ptr) { return ptr->num_params_r(); } +// [[Rcpp::export]] +Rcpp::List get_param_metadata(SEXP ext_model_ptr) { + Rcpp::XPtr ptr(ext_model_ptr); + std::vector param_names; + std::vector > param_dims; + ptr->get_param_names(param_names); + ptr->get_dims(param_dims); + + Rcpp::List param_metadata = Rcpp::List::create( + Rcpp::Named(param_names[0]) = param_dims[0] + ); + for (size_t i = 1; i < param_names.size(); i++) { + param_metadata.push_back(param_dims[i], param_names[i]); + } + + return param_metadata; +} + // [[Rcpp::export]] std::vector unconstrain_pars(SEXP ext_model_ptr, std::string init_path) { Rcpp::XPtr ptr(ext_model_ptr); @@ -79,12 +97,15 @@ std::vector unconstrain_pars(SEXP ext_model_ptr, std::string init_path) } // [[Rcpp::export]] -std::vector constrain_pars(SEXP ext_model_ptr, SEXP base_rng, std::vector upars) { +std::vector constrain_pars(SEXP ext_model_ptr, SEXP base_rng, + std::vector upars, + bool return_trans_pars, + bool return_gen_quants) { Rcpp::XPtr ptr(ext_model_ptr); Rcpp::XPtr rng(base_rng); std::vector params_i; std::vector vars; - ptr->write_array(*rng.get(), upars, params_i, vars, false, false); + ptr->write_array(*rng.get(), upars, params_i, vars, return_trans_pars, return_gen_quants); return vars; } diff --git a/man/fit-method-constrain_pars.Rd b/man/fit-method-constrain_pars.Rd index 513b48c28..cc4e8f821 100644 --- a/man/fit-method-constrain_pars.Rd +++ b/man/fit-method-constrain_pars.Rd @@ -5,10 +5,20 @@ \alias{constrain_pars} \title{Transform a set of unconstrained parameter values to the constrained scale} \usage{ -constrain_pars(upars) +constrain_pars( + upars, + transformed_parameters = TRUE, + generated_quantities = TRUE +) } \arguments{ \item{upars}{(numeric) A vector of unconstrained parameters to constrain} + +\item{transformed_parameters}{(boolean) Whether to return transformed parameters +implied by newly-constrained parameters (defaults to TRUE)} + +\item{generated_quantities}{(boolean) Whether to return generated quantities +implied by newly-constrained parameters (defaults to TRUE)} } \description{ The \verb{$constrain_pars()} method transforms input parameters to diff --git a/tests/testthat/test-model-methods.R b/tests/testthat/test-model-methods.R index 115d34417..da27134e7 100644 --- a/tests/testthat/test-model-methods.R +++ b/tests/testthat/test-model-methods.R @@ -1,7 +1,7 @@ context("model-methods") set_cmdstan_path() -mod <- cmdstan_model(testing_stan_file("bernoulli"), force_recompile = TRUE) +mod <- cmdstan_model(testing_stan_file("bernoulli_log_lik"), force_recompile = TRUE) data_list <- testing_data("bernoulli") fit <- mod$sample(data = data_list, chains = 1) @@ -57,12 +57,20 @@ test_that("Methods return correct values", { expect_equal(fit$hessian(upars=c(0.1)), hessian) cpars <- fit$constrain_pars(c(0.1)) - expect_equal(cpars, list(theta = 0.52497918747894001257)) + cpars_true <- list( + theta = 0.52497918747894001257, + log_lik = rep(-7.2439666007357095268, data_list$N) + ) + expect_equal(cpars, cpars_true) + + expect_equal(fit$constrain_pars(c(0.1), generated_quantities = FALSE), + list(theta = 0.52497918747894001257)) upars <- fit$unconstrain_pars(cpars) expect_equal(upars, c(0.1)) }) + test_that("methods error for incorrect inputs", { expect_error( fit$log_prob(c(1,2)), @@ -79,11 +87,6 @@ test_that("methods error for incorrect inputs", { "Model has 1 unconstrained parameter(s), but 2 were provided!", fixed = TRUE ) - expect_error( - fit$unconstrain_pars(list(theta = 0.5, dummy = 5)), - "Provided parameter(s): dummy not present in model!", - fixed = TRUE - ) expect_error( fit$constrain_pars(c(1,2)), "Model has 1 unconstrained parameter(s), but 2 were provided!", From 4ecfd85cda3b9d41fea40410aa8fce049aaca794 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sun, 16 Oct 2022 15:03:59 +0300 Subject: [PATCH 2/2] Extraneous call. --- R/fit.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/fit.R b/R/fit.R index eb1b4ab71..0bcb79ef3 100644 --- a/R/fit.R +++ b/R/fit.R @@ -332,8 +332,7 @@ init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) { if (hessian) { message("The hessian method relies on higher-order autodiff ", "which is still experimental. Please report any compilation ", - "errors that you encounter", - call. = FALSE) + "errors that you encounter") } message("Compiling additional model methods...") if (is.null(private$model_methods_env_$model_ptr)) {