Skip to content

Commit

Permalink
Merge pull request #3179 from stan-dev/model-base-unconstrain-array
Browse files Browse the repository at this point in the history
Expose new base class method 'unconstrain_array'
  • Loading branch information
WardBrian authored May 4, 2023
2 parents da13a2f + 02583ce commit d2203f1
Show file tree
Hide file tree
Showing 10 changed files with 386 additions and 7 deletions.
30 changes: 30 additions & 0 deletions src/stan/model/model_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,21 @@ class model_base : public prob_grad {
bool include_tparams = true, bool include_gqs = true,
std::ostream* msgs = 0) const = 0;

/**
* Convert the specified sequence of constrained parameters to a
* sequence of unconstrained parameters.
*
* This is the inverse of write_array. The output will be resized
* if necessary to match the number of unconstrained parameters.
*
* @param[in] params_r_constrained constrained parameters input
* @param[in,out] params_r unconstrained parameters produced
* @param[in,out] msgs msgs stream to which messages are written
*/
virtual void unconstrain_array(const Eigen::VectorXd& params_constrained_r,
Eigen::VectorXd& params_r,
std::ostream* msgs = nullptr) const = 0;

// TODO(carpenter): cut redundant std::vector versions from here ===

/**
Expand Down Expand Up @@ -605,6 +620,21 @@ class model_base : public prob_grad {
bool include_tparams = true, bool include_gqs = true,
std::ostream* msgs = 0) const = 0;

/**
* Convert the specified sequence of constrained parameters to a
* sequence of unconstrained parameters.
*
* This is the inverse of write_array. The output will be resized
* if necessary to match the number of unconstrained parameters.
*
* @param[in] params_r_constrained constrained parameters input
* @param[in,out] params_r unconstrained parameters produced
* @param[in,out] msgs msgs stream to which messages are written
*/
virtual void unconstrain_array(
const std::vector<double>& params_constrained_r,
std::vector<double>& params_r, std::ostream* msgs = nullptr) const = 0;

#ifdef STAN_MODEL_FVAR_VAR

/**
Expand Down
14 changes: 14 additions & 0 deletions src/stan/model/model_base_crtp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ class model_base_crtp : public stan::model::model_base {
rng, theta, vars, include_tparams, include_gqs, msgs);
}

void unconstrain_array(const Eigen::VectorXd& params_constrained_r,
Eigen::VectorXd& params_r,
std::ostream* msgs = nullptr) const override {
return static_cast<const M*>(this)->unconstrain_array(params_constrained_r,
params_r, msgs);
}

// TODO(carpenter): remove redundant std::vector methods below here =====
// ======================================================================

Expand Down Expand Up @@ -203,6 +210,13 @@ class model_base_crtp : public stan::model::model_base {
rng, theta, theta_i, vars, include_tparams, include_gqs, msgs);
}

void unconstrain_array(const std::vector<double>& params_constrained_r,
std::vector<double>& params_r,
std::ostream* msgs = nullptr) const override {
return static_cast<const M*>(this)->unconstrain_array(params_constrained_r,
params_r, msgs);
}

void transform_inits(const io::var_context& context,
Eigen::VectorXd& params_r,
std::ostream* msgs) const override {
Expand Down
11 changes: 4 additions & 7 deletions src/stan/services/sample/standalone_gqs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,13 @@ int standalone_generate(const Model &model, const Eigen::MatrixXd &draws,
std::vector<std::vector<size_t>> param_dimss;
get_model_parameters(model, param_names, param_dimss);

std::vector<int> dummy_params_i;
std::vector<double> unconstrained_params_r;
std::vector<double> row(draws.cols());

for (size_t i = 0; i < draws.rows(); ++i) {
dummy_params_i.clear();
unconstrained_params_r.clear();
Eigen::Map<Eigen::VectorXd>(&row[0], draws.cols()) = draws.row(i);
try {
stan::io::array_var_context context(param_names, draws.row(i),
param_dimss);
model.transform_inits(context, dummy_params_i, unconstrained_params_r,
&msg);
model.unconstrain_array(row, unconstrained_params_r, &msg);
} catch (const std::exception &e) {
if (msg.str().length() > 0)
logger.error(msg);
Expand Down
188 changes: 188 additions & 0 deletions src/test/test-models/good/model/parameters.inits.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
{
"theta": 3.35045497556837,
"sigma": [ 1.0678931954592095, 0.14205376411270393, 2.7743965352097821 ],
"mu": [ 2.1285977303026997, 2.3570177773034975, 5.6279470675750378 ],
"alpha": [
[
[
0.12208382701175115, 0.096277220888585938, 0.37179118427596042,
0.40984776782370236
],
[
0.41367090546791735, 0.4243728017966204, 0.074953157708614424,
0.087003135026847922
],
[
0.37512017100653144, 0.22485649523657397, 0.1430831011691662,
0.2569402325877283
]
],
[
[
0.2724045704160345, 0.36283264933917447, 0.19251820791261157,
0.17224457233217949
],
[
0.10266646872740415, 0.11717509867472625, 0.5250961979218981,
0.2550622346759715
],
[
0.052691526059705653, 0.39621004190419407, 0.1746323105072301,
0.3764661215288701
]
]
],
"cm": [
[
[ 2.2161729152433294, 4.2735083903241762 ],
[ 2.3228288774921975, 4.2551781219057006 ],
[ 4.0817896796501536, 4.4410534765422955 ],
[ 5.777652580986282, 2.702994187045916 ]
],
[
[ 4.4021064947227728, 6.2791571000818873 ],
[ 2.3751795650893412, 3.2065345846875557 ],
[ 3.333808860667193, 2.6610201337431665 ],
[ 2.6936180212748866, 5.8592677493484269 ]
],
[
[ 4.1249059398888281, 4.6356500684062221 ],
[ 4.1807348015868211, 6.38945134357675 ],
[ 5.6716804030142018, 6.9022942821354389 ],
[ 3.1672905280367103, 4.8241949183238049 ]
]
],
"L_Omega": [
[ 1.9177988996943971, 0.0, 0.0 ],
[ 1.4456130219514518, 1.0865133122191057, 0.0 ],
[ 1.6437414070749543, 0.50344658420634625, 1.7675118195549122 ]
],
"L_Corr": [
[ 1.0, 0.0, 0.0 ],
[ 0.84721537574654238, 0.5312495713867873, 0.0 ],
[ 0.60748819602989068, 0.60437187989848917, 0.51545389946368159 ]
],
"Omega": [
[
[
[
[ 1.0, 0.68939029205051539, 0.504611238921639 ],
[ 0.68939029205051539, 1.0000000000000002, 0.78401979256357612 ],
[ 0.504611238921639, 0.78401979256357612, 1.0 ]
],
[
[ 1.0, 0.74541410105288353, 0.30190927376354765 ],
[ 0.74541410105288353, 1.0, 0.44988416836084927 ],
[ 0.30190927376354765, 0.44988416836084927, 0.99999999999999978 ]
]
],
[
[
[ 1.0, 0.55326714683507616, 0.12986029237307539 ],
[ 0.55326714683507616, 1.0, 0.16728991788128289 ],
[ 0.12986029237307539, 0.16728991788128289, 1.0000000000000002 ]
],
[
[ 1.0, 0.721237684869733, 0.846045734857208 ],
[ 0.721237684869733, 1.0, 0.76818273298484252 ],
[ 0.846045734857208, 0.76818273298484252, 0.99999999999999989 ]
]
],
[
[
[ 1.0, 0.77072046015371076, 0.19772724769821418 ],
[ 0.77072046015371076, 1.0, 0.77299666827811075 ],
[ 0.19772724769821418, 0.77299666827811075, 1.0 ]
],
[
[ 1.0, 0.77442690127950775, 0.34050319470228807 ],
[ 0.77442690127950775, 1.0, 0.33517315284975119 ],
[ 0.34050319470228807, 0.33517315284975119, 0.99999999999999989 ]
]
]
],
[
[
[
[ 1.0, 0.824858543378145, 0.48649005735757422 ],
[ 0.824858543378145, 1.0000000000000002, 0.68242926355025413 ],
[ 0.48649005735757422, 0.68242926355025413, 0.99999999999999989 ]
],
[
[ 1.0, 0.70338407009562276, 0.57632477902590085 ],
[ 0.70338407009562276, 1.0000000000000002, 0.494867050063337 ],
[ 0.57632477902590085, 0.494867050063337, 0.99999999999999989 ]
]
],
[
[
[ 1.0, 0.57297498396235214, 0.41285332825575527 ],
[ 0.57297498396235214, 0.99999999999999989, 0.58082681602684771 ],
[ 0.41285332825575527, 0.58082681602684771, 0.99999999999999989 ]
],
[
[ 1.0, 0.96800009996808856, 0.77013498444134976 ],
[ 0.96800009996808856, 1.0, 0.75291247044061227 ],
[ 0.77013498444134976, 0.75291247044061227, 0.99999999999999978 ]
]
],
[
[
[ 1.0, 0.31657617071814048, 0.036375889145384352 ],
[ 0.31657617071814048, 1.0, 0.89147176254678817 ],
[ 0.036375889145384352, 0.89147176254678817, 1.0000000000000002 ]
],
[
[ 1.0, 0.999993509817939, 0.25954691874754277 ],
[ 0.999993509817939, 0.99999999999999989, 0.262726899218798 ],
[ 0.25954691874754277, 0.262726899218798, 0.99999999999999978 ]
]
]
]
],
"cv": [
[
[
[
[ 3.5433970587540458, 6.2950770670120129 ],
[ 5.6201329242297717, 2.3146297090598931 ],
[ 5.1462314404870781, 3.2721399347170417 ],
[ 4.0656750934461474, 3.2411893901314164 ]
],
[
[ 6.6251907292217185, 6.9002416987352486 ],
[ 5.0112130349146735, 3.6900697736636112 ],
[ 2.2371032314273811, 5.2247184067767574 ],
[ 2.17068504261216, 2.2456130023389376 ]
],
[
[ 5.80176022175848, 5.91129971601706 ],
[ 4.3864026258516891, 3.798097280619011 ],
[ 2.7099329809140325, 5.8044690521155928 ],
[ 4.4681514496492873, 4.0817140157307108 ]
]
],
[
[
[ 4.8785830147136737, 4.3952007395396215 ],
[ 3.6619384086120315, 2.8332738265524342 ],
[ 6.2210444243377845, 2.2427497248475379 ],
[ 3.7985260179585287, 2.8924615842522448 ]
],
[
[ 6.9248095019047611, 2.0818481647367837 ],
[ 2.1600854255133495, 4.6074404412211223 ],
[ 4.0638348657136438, 4.2456437508499167 ],
[ 6.8999181654782307, 2.6275592155970942 ]
],
[
[ 3.9680861372977834, 5.7360991610562344 ],
[ 4.18537248038814, 4.2370654246554018 ],
[ 6.9967852481487673, 3.2049914753910862 ],
[ 5.1202970072734892, 3.2138559765944259 ]
]
]
]
],
"p": 0.44452773799726492
}
21 changes: 21 additions & 0 deletions src/test/test-models/good/model/parameters.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// used to test parameter serialization/deserialization
parameters {
real theta;
array[3] real<lower=0> sigma;
vector[3] mu;
array[2, 3] simplex[4] alpha;
complex_matrix[3, 4] cm;
cholesky_factor_cov[3] L_Omega;
cholesky_factor_corr[3] L_Corr;
array[2, 3, 2] corr_matrix[3] Omega;
array[1, 2, 3] complex_vector[4] cv;
real<lower=0, upper=1> p;
}
transformed parameters {
vector[3] mu2;
mu2 = mu + 1;
}
generated quantities {
array[3] real y;
y = normal_rng(mu2, sigma);
}
89 changes: 89 additions & 0 deletions src/test/unit/model/array_functions_roundtrip_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#include <stan/model/log_prob_grad.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/io/json/json_data.hpp>
#include <test/test-models/good/model/parameters.hpp>
#include <test/unit/util.hpp>
#include <gtest/gtest.h>

auto get_init_json() {
std::vector<std::string> json_path = {
"src", "test", "test-models", "good", "model", "parameters.inits.json"};
std::string filename = paths_to_fname(json_path);
std::ifstream in(filename);
return std::unique_ptr<stan::io::var_context>(new stan::json::json_data(in));
}

class ModelArrayFunctionsRoundtripTest : public testing::Test {
public:
ModelArrayFunctionsRoundtripTest()
: model(context, 0, nullptr), rng(12324232), inits(get_init_json()) {}

stan::io::empty_var_context context;
std::unique_ptr<stan::io::var_context> inits;
stan_model model;
std::stringstream out;
boost::ecuyer1988 rng;

/**
* Test that the unconstrain_array function is the inverse of the
* write_array function. This tests the Eigen overloads.
*
* This calls transform_inits, write_array, and then unconstrain_array
* and asserts that the output of unconstrain_array is the same as the
* output of transform_inits.
*/
void eigen_round_trip(bool include_gq, bool include_tp) {
Eigen::VectorXd init_vector;
model.transform_inits(*inits, init_vector, &out);

Eigen::VectorXd written_vector;
model.write_array(rng, init_vector, written_vector, include_gq, include_tp,
&out);

Eigen::VectorXd recovered_vector;
model.unconstrain_array(written_vector, recovered_vector, &out);

EXPECT_MATRIX_NEAR(init_vector, recovered_vector, 1e-10);
EXPECT_EQ("", out.str());
}

/**
* Same as eigen_round_trip but for the std::vector overloads
*/
void std_vec_round_trip(bool include_gq, bool include_tp) {
std::vector<int> unused;
std::vector<double> init_vector;
model.transform_inits(*inits, unused, init_vector, &out);

std::vector<double> written_vector;
model.write_array(rng, init_vector, unused, written_vector, include_gq,
include_tp, &out);

std::vector<double> recovered_vector;
model.unconstrain_array(written_vector, recovered_vector, &out);

for (int i = 0; i < init_vector.size(); i++) {
EXPECT_NEAR(init_vector[i], recovered_vector[i], 1e-10);
}

EXPECT_EQ("", out.str());
}
};

// test all combinations of include_gq and include_tp.
// unconstrain_array should ignore them as they appear at the end
// of the written vectors

TEST_F(ModelArrayFunctionsRoundtripTest, eigen_overloads) {
eigen_round_trip(false, false);
eigen_round_trip(false, true);
eigen_round_trip(true, false);
eigen_round_trip(true, true);
}

TEST_F(ModelArrayFunctionsRoundtripTest, std_vector_overloads) {
std_vec_round_trip(false, false);
std_vec_round_trip(false, true);
std_vec_round_trip(true, false);
std_vec_round_trip(true, true);
}
Loading

0 comments on commit d2203f1

Please sign in to comment.