Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose new base class method 'unconstrain_array' #3179

Merged
merged 10 commits into from
May 4, 2023
30 changes: 30 additions & 0 deletions src/stan/model/model_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,21 @@ class model_base : public prob_grad {
bool include_tparams = true, bool include_gqs = true,
std::ostream* msgs = 0) const = 0;

/**
* Convert the specified sequence of constrained parameters to a
* sequence of unconstrained parameters.
*
* This is the inverse of write_array. The output will be resized
* if necessary to match the number of unconstrained parameters.
*
* @param[in] params_r_constrained constrained parameters input
* @param[in,out] params_r unconstrained parameters produced
* @param[in,out] msgs msgs stream to which messages are written
*/
virtual void unconstrain_array(const Eigen::VectorXd& params_constrained_r,
Eigen::VectorXd& params_r,
std::ostream* msgs = nullptr) const = 0;

// TODO(carpenter): cut redundant std::vector versions from here ===

/**
Expand Down Expand Up @@ -605,6 +620,21 @@ class model_base : public prob_grad {
bool include_tparams = true, bool include_gqs = true,
std::ostream* msgs = 0) const = 0;

/**
* Convert the specified sequence of constrained parameters to a
* sequence of unconstrained parameters.
*
* This is the inverse of write_array. The output will be resized
* if necessary to match the number of unconstrained parameters.
*
* @param[in] params_r_constrained constrained parameters input
* @param[in,out] params_r unconstrained parameters produced
* @param[in,out] msgs msgs stream to which messages are written
*/
virtual void unconstrain_array(
const std::vector<double>& params_constrained_r,
std::vector<double>& params_r, std::ostream* msgs = nullptr) const = 0;

#ifdef STAN_MODEL_FVAR_VAR

/**
Expand Down
14 changes: 14 additions & 0 deletions src/stan/model/model_base_crtp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ class model_base_crtp : public stan::model::model_base {
rng, theta, vars, include_tparams, include_gqs, msgs);
}

void unconstrain_array(const Eigen::VectorXd& params_constrained_r,
Eigen::VectorXd& params_r,
std::ostream* msgs = nullptr) const override {
return static_cast<const M*>(this)->unconstrain_array(params_constrained_r,
params_r, msgs);
}

// TODO(carpenter): remove redundant std::vector methods below here =====
// ======================================================================

Expand Down Expand Up @@ -203,6 +210,13 @@ class model_base_crtp : public stan::model::model_base {
rng, theta, theta_i, vars, include_tparams, include_gqs, msgs);
}

void unconstrain_array(const std::vector<double>& params_constrained_r,
std::vector<double>& params_r,
std::ostream* msgs = nullptr) const override {
return static_cast<const M*>(this)->unconstrain_array(params_constrained_r,
params_r, msgs);
}

void transform_inits(const io::var_context& context,
Eigen::VectorXd& params_r,
std::ostream* msgs) const override {
Expand Down
11 changes: 4 additions & 7 deletions src/stan/services/sample/standalone_gqs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,13 @@ int standalone_generate(const Model &model, const Eigen::MatrixXd &draws,
std::vector<std::vector<size_t>> param_dimss;
get_model_parameters(model, param_names, param_dimss);

std::vector<int> dummy_params_i;
std::vector<double> unconstrained_params_r;
std::vector<double> row(draws.cols());

for (size_t i = 0; i < draws.rows(); ++i) {
dummy_params_i.clear();
unconstrained_params_r.clear();
Eigen::Map<Eigen::VectorXd>(&row[0], draws.cols()) = draws.row(i);
try {
stan::io::array_var_context context(param_names, draws.row(i),
param_dimss);
model.transform_inits(context, dummy_params_i, unconstrained_params_r,
&msg);
model.unconstrain_array(row, unconstrained_params_r, &msg);
} catch (const std::exception &e) {
if (msg.str().length() > 0)
logger.error(msg);
Expand Down
188 changes: 188 additions & 0 deletions src/test/test-models/good/model/parameters.inits.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
{
"theta": 3.35045497556837,
"sigma": [ 1.0678931954592095, 0.14205376411270393, 2.7743965352097821 ],
"mu": [ 2.1285977303026997, 2.3570177773034975, 5.6279470675750378 ],
"alpha": [
[
[
0.12208382701175115, 0.096277220888585938, 0.37179118427596042,
0.40984776782370236
],
[
0.41367090546791735, 0.4243728017966204, 0.074953157708614424,
0.087003135026847922
],
[
0.37512017100653144, 0.22485649523657397, 0.1430831011691662,
0.2569402325877283
]
],
[
[
0.2724045704160345, 0.36283264933917447, 0.19251820791261157,
0.17224457233217949
],
[
0.10266646872740415, 0.11717509867472625, 0.5250961979218981,
0.2550622346759715
],
[
0.052691526059705653, 0.39621004190419407, 0.1746323105072301,
0.3764661215288701
]
]
],
"cm": [
[
[ 2.2161729152433294, 4.2735083903241762 ],
[ 2.3228288774921975, 4.2551781219057006 ],
[ 4.0817896796501536, 4.4410534765422955 ],
[ 5.777652580986282, 2.702994187045916 ]
],
[
[ 4.4021064947227728, 6.2791571000818873 ],
[ 2.3751795650893412, 3.2065345846875557 ],
[ 3.333808860667193, 2.6610201337431665 ],
[ 2.6936180212748866, 5.8592677493484269 ]
],
[
[ 4.1249059398888281, 4.6356500684062221 ],
[ 4.1807348015868211, 6.38945134357675 ],
[ 5.6716804030142018, 6.9022942821354389 ],
[ 3.1672905280367103, 4.8241949183238049 ]
]
],
"L_Omega": [
[ 1.9177988996943971, 0.0, 0.0 ],
[ 1.4456130219514518, 1.0865133122191057, 0.0 ],
[ 1.6437414070749543, 0.50344658420634625, 1.7675118195549122 ]
],
"L_Corr": [
[ 1.0, 0.0, 0.0 ],
[ 0.84721537574654238, 0.5312495713867873, 0.0 ],
[ 0.60748819602989068, 0.60437187989848917, 0.51545389946368159 ]
],
"Omega": [
[
[
[
[ 1.0, 0.68939029205051539, 0.504611238921639 ],
[ 0.68939029205051539, 1.0000000000000002, 0.78401979256357612 ],
[ 0.504611238921639, 0.78401979256357612, 1.0 ]
],
[
[ 1.0, 0.74541410105288353, 0.30190927376354765 ],
[ 0.74541410105288353, 1.0, 0.44988416836084927 ],
[ 0.30190927376354765, 0.44988416836084927, 0.99999999999999978 ]
]
],
[
[
[ 1.0, 0.55326714683507616, 0.12986029237307539 ],
[ 0.55326714683507616, 1.0, 0.16728991788128289 ],
[ 0.12986029237307539, 0.16728991788128289, 1.0000000000000002 ]
],
[
[ 1.0, 0.721237684869733, 0.846045734857208 ],
[ 0.721237684869733, 1.0, 0.76818273298484252 ],
[ 0.846045734857208, 0.76818273298484252, 0.99999999999999989 ]
]
],
[
[
[ 1.0, 0.77072046015371076, 0.19772724769821418 ],
[ 0.77072046015371076, 1.0, 0.77299666827811075 ],
[ 0.19772724769821418, 0.77299666827811075, 1.0 ]
],
[
[ 1.0, 0.77442690127950775, 0.34050319470228807 ],
[ 0.77442690127950775, 1.0, 0.33517315284975119 ],
[ 0.34050319470228807, 0.33517315284975119, 0.99999999999999989 ]
]
]
],
[
[
[
[ 1.0, 0.824858543378145, 0.48649005735757422 ],
[ 0.824858543378145, 1.0000000000000002, 0.68242926355025413 ],
[ 0.48649005735757422, 0.68242926355025413, 0.99999999999999989 ]
],
[
[ 1.0, 0.70338407009562276, 0.57632477902590085 ],
[ 0.70338407009562276, 1.0000000000000002, 0.494867050063337 ],
[ 0.57632477902590085, 0.494867050063337, 0.99999999999999989 ]
]
],
[
[
[ 1.0, 0.57297498396235214, 0.41285332825575527 ],
[ 0.57297498396235214, 0.99999999999999989, 0.58082681602684771 ],
[ 0.41285332825575527, 0.58082681602684771, 0.99999999999999989 ]
],
[
[ 1.0, 0.96800009996808856, 0.77013498444134976 ],
[ 0.96800009996808856, 1.0, 0.75291247044061227 ],
[ 0.77013498444134976, 0.75291247044061227, 0.99999999999999978 ]
]
],
[
[
[ 1.0, 0.31657617071814048, 0.036375889145384352 ],
[ 0.31657617071814048, 1.0, 0.89147176254678817 ],
[ 0.036375889145384352, 0.89147176254678817, 1.0000000000000002 ]
],
[
[ 1.0, 0.999993509817939, 0.25954691874754277 ],
[ 0.999993509817939, 0.99999999999999989, 0.262726899218798 ],
[ 0.25954691874754277, 0.262726899218798, 0.99999999999999978 ]
]
]
]
],
"cv": [
[
[
[
[ 3.5433970587540458, 6.2950770670120129 ],
[ 5.6201329242297717, 2.3146297090598931 ],
[ 5.1462314404870781, 3.2721399347170417 ],
[ 4.0656750934461474, 3.2411893901314164 ]
],
[
[ 6.6251907292217185, 6.9002416987352486 ],
[ 5.0112130349146735, 3.6900697736636112 ],
[ 2.2371032314273811, 5.2247184067767574 ],
[ 2.17068504261216, 2.2456130023389376 ]
],
[
[ 5.80176022175848, 5.91129971601706 ],
[ 4.3864026258516891, 3.798097280619011 ],
[ 2.7099329809140325, 5.8044690521155928 ],
[ 4.4681514496492873, 4.0817140157307108 ]
]
],
[
[
[ 4.8785830147136737, 4.3952007395396215 ],
[ 3.6619384086120315, 2.8332738265524342 ],
[ 6.2210444243377845, 2.2427497248475379 ],
[ 3.7985260179585287, 2.8924615842522448 ]
],
[
[ 6.9248095019047611, 2.0818481647367837 ],
[ 2.1600854255133495, 4.6074404412211223 ],
[ 4.0638348657136438, 4.2456437508499167 ],
[ 6.8999181654782307, 2.6275592155970942 ]
],
[
[ 3.9680861372977834, 5.7360991610562344 ],
[ 4.18537248038814, 4.2370654246554018 ],
[ 6.9967852481487673, 3.2049914753910862 ],
[ 5.1202970072734892, 3.2138559765944259 ]
]
]
]
],
"p": 0.44452773799726492
}
21 changes: 21 additions & 0 deletions src/test/test-models/good/model/parameters.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// used to test parameter serialization/deserialization
parameters {
real theta;
array[3] real<lower=0> sigma;
vector[3] mu;
array[2, 3] simplex[4] alpha;
complex_matrix[3, 4] cm;
cholesky_factor_cov[3] L_Omega;
cholesky_factor_corr[3] L_Corr;
array[2, 3, 2] corr_matrix[3] Omega;
array[1, 2, 3] complex_vector[4] cv;
real<lower=0, upper=1> p;
}
transformed parameters {
vector[3] mu2;
mu2 = mu + 1;
}
generated quantities {
array[3] real y;
y = normal_rng(mu2, sigma);
}
91 changes: 91 additions & 0 deletions src/test/unit/model/array_functions_roundtrip_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include <stan/model/log_prob_grad.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/io/json/json_data.hpp>
#include <test/test-models/good/model/parameters.hpp>
#include <test/unit/util.hpp>
#include <gtest/gtest.h>

class ModelArrayFunctionsRoundtripTest : public testing::Test {
public:
ModelArrayFunctionsRoundtripTest()
: model(context, 0, nullptr), rng(12324232), inits(nullptr) {
out.str("");

std::vector<std::string> json_path = {
"src", "test", "test-models", "good", "model", "parameters.inits.json"};
std::string filename = paths_to_fname(json_path);
std::ifstream in(filename);
inits = new stan::json::json_data(in);
}

~ModelArrayFunctionsRoundtripTest() { delete inits; }
WardBrian marked this conversation as resolved.
Show resolved Hide resolved

stan::io::empty_var_context context;
stan::io::var_context* inits;
stan_model model;
std::stringstream out;
boost::ecuyer1988 rng;

/**
* Test that the unconstrain_array function is the inverse of the
* write_array function. This tests the Eigen overloads.
*
* This calls transform_inits, write_array, and then unconstrain_array
* and asserts that the output of unconstrain_array is the same as the
* output of transform_inits.
*/
void eigen_round_trip(bool include_gq, bool include_tp) {
Eigen::VectorXd init_vector;
model.transform_inits(*inits, init_vector, &out);

Eigen::VectorXd written_vector;
model.write_array(rng, init_vector, written_vector, include_gq, include_tp,
&out);

Eigen::VectorXd recovered_vector;
model.unconstrain_array(written_vector, recovered_vector, &out);

EXPECT_MATRIX_NEAR(init_vector, recovered_vector, 1e-10);
EXPECT_EQ("", out.str());
}

/**
* Same as eigen_round_trip but for the std::vector overloads
*/
void std_vec_round_trip(bool include_gq, bool include_tp) {
std::vector<int> unused;
std::vector<double> init_vector;
model.transform_inits(*inits, unused, init_vector, &out);

std::vector<double> written_vector;
model.write_array(rng, init_vector, unused, written_vector, include_gq,
include_tp, &out);

std::vector<double> recovered_vector;
model.unconstrain_array(written_vector, recovered_vector, &out);

for (int i = 0; i < init_vector.size(); i++) {
EXPECT_NEAR(init_vector[i], recovered_vector[i], 1e-10);
}

EXPECT_EQ("", out.str());
}
};

// test all combinations of include_gq and include_tp.
// unconstrain_array should ignore them as they appear at the end
// of the written vectors

TEST_F(ModelArrayFunctionsRoundtripTest, eigen_overloads) {
eigen_round_trip(false, false);
eigen_round_trip(false, true);
eigen_round_trip(true, false);
eigen_round_trip(true, true);
}

TEST_F(ModelArrayFunctionsRoundtripTest, std_vector_overloads) {
std_vec_round_trip(false, false);
std_vec_round_trip(false, true);
std_vec_round_trip(true, false);
std_vec_round_trip(true, true);
}
Loading