From 7d2f9c64dc5bf89734e84a7ccb502eaf52fe9114 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 6 Apr 2023 09:19:15 -0400 Subject: [PATCH 1/9] Expose new base class method 'unconstrain_array' --- src/stan/model/model_base.hpp | 31 +++++++++++++++++++++ src/stan/model/model_base_crtp.hpp | 7 ----- src/stan/services/sample/standalone_gqs.hpp | 11 +++----- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/stan/model/model_base.hpp b/src/stan/model/model_base.hpp index 45972db8284..f955dfeec35 100644 --- a/src/stan/model/model_base.hpp +++ b/src/stan/model/model_base.hpp @@ -368,6 +368,22 @@ class model_base : public prob_grad { bool include_tparams = true, bool include_gqs = true, std::ostream* msgs = 0) const = 0; + + /** + * Convert the specified sequence of constrained parameters to a + * sequence of unconstrained parameters. + * + * This is the inverse of write_array. The output will be resized + * if necessary to match the number of unconstrained parameters. + * + * @param[in] params_r_constrained constrained parameters input + * @param[in,out] params_r unconstrained parameters produced + * @param[in,out] msgs msgs stream to which messages are written + */ + virtual void unconstrain_array(const Eigen::VectorXd& params_constrained_r, + Eigen::VectorXd& params_r, + std::ostream* msgs = nullptr) const = 0; + // TODO(carpenter): cut redundant std::vector versions from here === /** @@ -605,6 +621,21 @@ class model_base : public prob_grad { bool include_tparams = true, bool include_gqs = true, std::ostream* msgs = 0) const = 0; + /** + * Convert the specified sequence of constrained parameters to a + * sequence of unconstrained parameters. + * + * This is the inverse of write_array. The output will be resized + * if necessary to match the number of unconstrained parameters. + * + * @param[in] params_r_constrained constrained parameters input + * @param[in,out] params_r unconstrained parameters produced + * @param[in,out] msgs msgs stream to which messages are written + */ + virtual void unconstrain_array(const std::vector& params_constrained_r, + std::vector& params_r, + std::ostream* msgs = nullptr) const = 0; + #ifdef STAN_MODEL_FVAR_VAR /** diff --git a/src/stan/model/model_base_crtp.hpp b/src/stan/model/model_base_crtp.hpp index 4645d706a22..96f85a5b3e0 100644 --- a/src/stan/model/model_base_crtp.hpp +++ b/src/stan/model/model_base_crtp.hpp @@ -203,13 +203,6 @@ class model_base_crtp : public stan::model::model_base { rng, theta, theta_i, vars, include_tparams, include_gqs, msgs); } - void transform_inits(const io::var_context& context, - Eigen::VectorXd& params_r, - std::ostream* msgs) const override { - return static_cast(this)->transform_inits(context, params_r, - msgs); - } - #ifdef STAN_MODEL_FVAR_VAR /** diff --git a/src/stan/services/sample/standalone_gqs.hpp b/src/stan/services/sample/standalone_gqs.hpp index 12528840bac..3f99d049167 100644 --- a/src/stan/services/sample/standalone_gqs.hpp +++ b/src/stan/services/sample/standalone_gqs.hpp @@ -96,16 +96,13 @@ int standalone_generate(const Model &model, const Eigen::MatrixXd &draws, std::vector> param_dimss; get_model_parameters(model, param_names, param_dimss); - std::vector dummy_params_i; std::vector unconstrained_params_r; + std::vector row(draws.cols()); + for (size_t i = 0; i < draws.rows(); ++i) { - dummy_params_i.clear(); - unconstrained_params_r.clear(); + Eigen::Map(&row[0], draws.cols()) = draws.row(i); try { - stan::io::array_var_context context(param_names, draws.row(i), - param_dimss); - model.transform_inits(context, dummy_params_i, unconstrained_params_r, - &msg); + model.unconstrain_array(row, unconstrained_params_r, &msg); } catch (const std::exception &e) { if (msg.str().length() > 0) logger.error(msg); From 9f3952f857f8c39868e65fe21a1ea8ceb5a6d9a9 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 11 Apr 2023 12:31:51 -0400 Subject: [PATCH 2/9] Update tests --- src/test/unit/model/model_base_crtp_test.cpp | 7 +++++++ src/test/unit/model/model_base_test.cpp | 8 ++++++++ src/test/unit/services/util/mcmc_writer_test.cpp | 7 +++++++ 3 files changed, 22 insertions(+) diff --git a/src/test/unit/model/model_base_crtp_test.cpp b/src/test/unit/model/model_base_crtp_test.cpp index 869c6ed34f2..1fb5db0a808 100644 --- a/src/test/unit/model/model_base_crtp_test.cpp +++ b/src/test/unit/model/model_base_crtp_test.cpp @@ -101,6 +101,13 @@ struct mock_model : public stan::model::model_base_crtp { std::vector& params_r_constrained, bool include_tparams, bool include_gqs, std::ostream* msgs) const {} + + void unconstrain_array(const Eigen::VectorXd& params_constrained_r, + Eigen::VectorXd& params_r, + std::ostream* msgs = nullptr) const override {} + void unconstrain_array(const std::vector& params_constrained_r, + std::vector& params_r, + std::ostream* msgs = nullptr) const override {} }; TEST(model, modelBaseInheritance) { diff --git a/src/test/unit/model/model_base_test.cpp b/src/test/unit/model/model_base_test.cpp index 4a97ced1f90..53434c67739 100644 --- a/src/test/unit/model/model_base_test.cpp +++ b/src/test/unit/model/model_base_test.cpp @@ -83,6 +83,10 @@ struct mock_model : public stan::model::model_base { Eigen::VectorXd& params_constrained_r, bool include_tparams, bool include_gqs, std::ostream* msgs) const override {} + void unconstrain_array(const Eigen::VectorXd& params_constrained_r, + Eigen::VectorXd& params_r, + std::ostream* msgs = nullptr) const override {} + double log_prob(std::vector& params_r, std::vector& params_i, std::ostream* msgs) const override { return 11; @@ -141,6 +145,10 @@ struct mock_model : public stan::model::model_base { bool include_tparams, bool include_gqs, std::ostream* msgs) const override {} + void unconstrain_array(const std::vector& params_constrained_r, + std::vector& params_r, + std::ostream* msgs = nullptr) const override {} + #ifdef STAN_MODEL_FVAR_VAR stan::math::fvar log_prob( diff --git a/src/test/unit/services/util/mcmc_writer_test.cpp b/src/test/unit/services/util/mcmc_writer_test.cpp index f1d027403ae..49152526dc4 100644 --- a/src/test/unit/services/util/mcmc_writer_test.cpp +++ b/src/test/unit/services/util/mcmc_writer_test.cpp @@ -149,6 +149,13 @@ class throwing_model : public stan::model::model_base_crtp { const stan::io::var_context& context, Eigen::Matrix& params_r, std::ostream* pstream__ = nullptr) const {} + + void unconstrain_array(const Eigen::VectorXd& params_constrained_r, + Eigen::VectorXd& params_r, + std::ostream* msgs = nullptr) const override {} + void unconstrain_array(const std::vector& params_constrained_r, + std::vector& params_r, + std::ostream* msgs = nullptr) const override {} }; } // namespace test From 284861468fccdea0754da107cd38b98fc9033c2c Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 11 Apr 2023 14:05:54 -0400 Subject: [PATCH 3/9] Add tests of round-trip with write_array --- .../good/model/parameters.inits.json | 188 ++++++++++++++++++ .../test-models/good/model/parameters.stan | 13 ++ .../model/array_functions_roundtrip_test.cpp | 60 ++++++ src/test/unit/util.hpp | 18 ++ 4 files changed, 279 insertions(+) create mode 100644 src/test/test-models/good/model/parameters.inits.json create mode 100644 src/test/test-models/good/model/parameters.stan create mode 100644 src/test/unit/model/array_functions_roundtrip_test.cpp diff --git a/src/test/test-models/good/model/parameters.inits.json b/src/test/test-models/good/model/parameters.inits.json new file mode 100644 index 00000000000..41d2deb0d2c --- /dev/null +++ b/src/test/test-models/good/model/parameters.inits.json @@ -0,0 +1,188 @@ +{ + "theta": 6.0259236924236479, + "sigma": [ 0.49404628963957464, 1.7890770781045591 ], + "mu": [ 2.3614558521457374, 4.7013980779374336, 5.9669822366479046 ], + "alpha": [ + [ + [ + 0.48519605663895082, 0.072476485656173609, 0.2016497322433165, + 0.24067772546155902 + ], + [ + 0.29194266525441742, 0.18015647786812236, 0.182471506206822, + 0.34542935067063812 + ], + [ + 0.0088315040897732552, 0.17056928511037653, 0.48201658293201283, + 0.33858262786783733 + ] + ], + [ + [ + 0.054263547813664666, 0.68720978364541985, 0.0898770345967127, + 0.16864963394420274 + ], + [ + 0.2342347013480465, 0.52898483516380979, 0.22007908017060804, + 0.016701383317535796 + ], + [ + 0.46151613790635287, 0.21628949924286106, 0.083506221199266831, + 0.23868814165151919 + ] + ] + ], + "cm": [ + [ + [ 5.8742928607162987, 3.37481110970312 ], + [ 3.8109641809714816, 3.479741575288458 ], + [ 4.9095671336661875, 2.9562261117925726 ], + [ 2.4883741076653534, 2.6270517434309695 ] + ], + [ + [ 4.2306529269790234, 2.8089199427657352 ], + [ 2.1435389197334103, 6.3227988268983939 ], + [ 4.4697237280322586, 3.9097197582393859 ], + [ 2.4234778406541011, 6.5293446112399147 ] + ], + [ + [ 6.3744994650304063, 3.3658103239740145 ], + [ 2.6241457673137845, 5.8378916448461453 ], + [ 3.0834646138512292, 5.086041093051799 ], + [ 3.3537337347439853, 2.0610400332102468 ] + ] + ], + "L_Omega": [ + [ 1.6859019351132143, 0.0, 0.0 ], + [ 1.5100530757160859, 1.9540945519975061, 0.0 ], + [ 0.80305935498225411, 0.41755234613412578, 0.34527779020597843 ] + ], + "L_Corr": [ + [ 1.0, 0.0, 0.0 ], + [ 0.7252883286260241, 0.68844523410280678, 0.0 ], + [ 0.59817858939472945, 0.19697185555697858, 0.77677825877670681 ] + ], + "Omega": [ + [ + [ + [ + [ 1.0, 0.55001344718079914, 0.66782858658963418 ], + [ 0.55001344718079914, 1.0, 0.98624028384536921 ], + [ 0.66782858658963418, 0.98624028384536921, 0.99999999999999978 ] + ], + [ + [ 1.0, 0.9318909439172286, 0.12646004908579564 ], + [ 0.9318909439172286, 1.0, 0.38310995933766717 ], + [ 0.12646004908579564, 0.38310995933766717, 1.0000000000000004 ] + ] + ], + [ + [ + [ 1.0, 0.49318905583978717, 0.35752652655917544 ], + [ 0.49318905583978717, 0.99999999999999989, 0.71232427460865166 ], + [ 0.35752652655917544, 0.71232427460865166, 1.0 ] + ], + [ + [ 1.0, 0.79324509354994, 0.725578409364175 ], + [ 0.79324509354994, 1.0000000000000002, 0.95577977609788223 ], + [ 0.725578409364175, 0.95577977609788223, 0.99999999999999989 ] + ] + ], + [ + [ + [ 1.0, 0.36584941582821728, 0.13479496359958917 ], + [ 0.36584941582821728, 1.0, 0.56339796833654132 ], + [ 0.13479496359958917, 0.56339796833654132, 1.0000000000000002 ] + ], + [ + [ 1.0, 0.50976927614597622, 0.80207673991446571 ], + [ 0.50976927614597622, 1.0, 0.68850962645333658 ], + [ 0.80207673991446571, 0.68850962645333658, 1.0 ] + ] + ] + ], + [ + [ + [ + [ 1.0, 0.6761981720506417, 0.14048539967195806 ], + [ 0.6761981720506417, 1.0000000000000002, 0.59606763590879885 ], + [ 0.14048539967195806, 0.59606763590879885, 0.99999999999999978 ] + ], + [ + [ 1.0, 0.83267548540315484, 0.1161421868513267 ], + [ 0.83267548540315484, 1.0, 0.36127622503604984 ], + [ 0.1161421868513267, 0.36127622503604984, 1.0 ] + ] + ], + [ + [ + [ 1.0, 0.65094812690160486, 0.24833571106660485 ], + [ 0.65094812690160486, 0.99999999999999978, 0.87861205469182324 ], + [ 0.24833571106660485, 0.87861205469182324, 0.99999999999999989 ] + ], + [ + [ 1.0, 0.87121505620140771, 0.096546271520155144 ], + [ 0.87121505620140771, 1.0, 0.40952490643483219 ], + [ 0.096546271520155144, 0.40952490643483219, 1.0 ] + ] + ], + [ + [ + [ 1.0, 0.84363172177952406, 0.45884050416569433 ], + [ 0.84363172177952406, 1.0, 0.81636519430225407 ], + [ 0.45884050416569433, 0.81636519430225407, 0.99999999999999978 ] + ], + [ + [ 1.0, 0.99893086230667771, 0.51520325383243815 ], + [ 0.99893086230667771, 1.0000000000000002, 0.54737233709009459 ], + [ 0.51520325383243815, 0.54737233709009459, 1.0000000000000002 ] + ] + ] + ] + ], + "cv": [ + [ + [ + [ + [ 5.1663094098648292, 2.6992998584233279 ], + [ 3.516087101086983, 3.8723933583731247 ], + [ 3.0316488865424223, 3.4729341839760846 ], + [ 5.5066971216820511, 3.3859493386390422 ] + ], + [ + [ 6.8658257903710336, 4.3600362419403371 ], + [ 3.9211239216617804, 2.8351155303955791 ], + [ 4.1195086666464125, 3.8023547393683415 ], + [ 6.2580026511104849, 2.1005711764538493 ] + ], + [ + [ 4.7268899914650859, 4.5923050316777463 ], + [ 3.4762358667133579, 4.2442096457788274 ], + [ 3.5455589406474659, 5.4652971907441561 ], + [ 6.3153655925600107, 3.3120627802259204 ] + ] + ], + [ + [ + [ 3.9162392626483817, 3.749689288247267 ], + [ 5.0089472710020306, 5.5700410415594721 ], + [ 5.3697354862617646, 4.9141010702010668 ], + [ 4.7877738836159667, 4.6208347757670065 ] + ], + [ + [ 2.6180703830370833, 5.8691356845339673 ], + [ 6.33506959006537, 6.52981582099901 ], + [ 4.502995465196248, 4.377433887799496 ], + [ 2.7662057362027213, 6.2876357933511784 ] + ], + [ + [ 5.3389919152354084, 4.4303064363263145 ], + [ 2.41243332010433, 2.826547127396049 ], + [ 4.6649955354162387, 6.208905770686683 ], + [ 4.84395213093654, 5.9671366331372635 ] + ] + ] + ] + ], + "p": 0.85714999002788439 +} diff --git a/src/test/test-models/good/model/parameters.stan b/src/test/test-models/good/model/parameters.stan new file mode 100644 index 00000000000..bdd9cfacf19 --- /dev/null +++ b/src/test/test-models/good/model/parameters.stan @@ -0,0 +1,13 @@ +// used to test parameter serialization/deserialization +parameters { + real theta; + array[2] real sigma; + vector[3] mu; + array[2, 3] simplex[4] alpha; + complex_matrix[3, 4] cm; + cholesky_factor_cov[3] L_Omega; + cholesky_factor_corr[3] L_Corr; + array[2, 3, 2] corr_matrix[3] Omega; + array[1, 2, 3] complex_vector[4] cv; + real p; +} diff --git a/src/test/unit/model/array_functions_roundtrip_test.cpp b/src/test/unit/model/array_functions_roundtrip_test.cpp new file mode 100644 index 00000000000..d572324ba9c --- /dev/null +++ b/src/test/unit/model/array_functions_roundtrip_test.cpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include +#include +#include + +TEST(ModelUtil, write_array_unconstrain_array_roundtrip) { + stan::io::empty_var_context data_var_context; + stan_model model(data_var_context, 0, static_cast(0)); + + std::vector json_path; + json_path = {"src", "test", "test-models", + "good", "model", "parameters.inits.json"}; + std::string filename = paths_to_fname(json_path); + std::ifstream in(filename); + stan::json::json_data inits(in); + + std::stringstream out; + out.str(""); + + // unused in this model but needed for write_array + auto rng = stan::services::util::create_rng(12324232, 1); + + try { + Eigen::VectorXd init_vector; + model.transform_inits(inits, init_vector, &out); + + Eigen::VectorXd written_vector; + model.write_array(rng, init_vector, written_vector, &out); + + Eigen::VectorXd recovered_vector; + model.unconstrain_array(written_vector, recovered_vector, &out); + + EXPECT_MATRIX_NEAR(init_vector, recovered_vector, 1e-10); + EXPECT_EQ("", out.str()); + } catch (...) { + FAIL() << "write_array_unconstrain_array_roundtrip Eigen::VectorXd"; + } + + try { + std::vector unused; + std::vector init_vector; + model.transform_inits(inits, unused, init_vector, &out); + + std::vector written_vector; + model.write_array(rng, init_vector, unused, written_vector, &out); + + std::vector recovered_vector; + model.unconstrain_array(written_vector, recovered_vector, &out); + + for (int i = 0; i < init_vector.size(); i++) { + EXPECT_NEAR(init_vector[i], recovered_vector[i], 1e-10); + } + + EXPECT_EQ("", out.str()); + } catch (...) { + FAIL() << "write_array_unconstrain_array_roundtrip std::vector"; + } +} diff --git a/src/test/unit/util.hpp b/src/test/unit/util.hpp index a5de5bdbff0..5c712b8971f 100644 --- a/src/test/unit/util.hpp +++ b/src/test/unit/util.hpp @@ -89,6 +89,24 @@ void match_csv_columns(const Eigen::MatrixXd& samples, } #endif +#ifndef EXPECT_MATRIX_NEAR +#define EXPECT_MATRIX_NEAR(A, B, DELTA) \ + { \ + using T_A = std::decay_t; \ + using T_B = std::decay_t; \ + const Eigen::Matrix \ + A_eval = A; \ + const Eigen::Matrix \ + B_eval = B; \ + EXPECT_EQ(A_eval.rows(), B_eval.rows()); \ + EXPECT_EQ(A_eval.cols(), B_eval.cols()); \ + for (int i = 0; i < A_eval.size(); i++) \ + EXPECT_NEAR(A_eval(i), B_eval(i), DELTA); \ + } +#endif + /** * Gets the path separator for the OS. * From ea2a1f5a506a85cdf7ec628c99735f421c2e9978 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 11 Apr 2023 15:17:07 -0400 Subject: [PATCH 4/9] Extend tests --- .../good/model/parameters.inits.json | 186 +++++++++--------- .../test-models/good/model/parameters.stan | 10 +- .../model/array_functions_roundtrip_test.cpp | 77 +++++--- 3 files changed, 156 insertions(+), 117 deletions(-) diff --git a/src/test/test-models/good/model/parameters.inits.json b/src/test/test-models/good/model/parameters.inits.json index 41d2deb0d2c..10405a7b886 100644 --- a/src/test/test-models/good/model/parameters.inits.json +++ b/src/test/test-models/good/model/parameters.inits.json @@ -1,141 +1,141 @@ { - "theta": 6.0259236924236479, - "sigma": [ 0.49404628963957464, 1.7890770781045591 ], - "mu": [ 2.3614558521457374, 4.7013980779374336, 5.9669822366479046 ], + "theta": 3.35045497556837, + "sigma": [ 1.0678931954592095, 0.14205376411270393, 2.7743965352097821 ], + "mu": [ 2.1285977303026997, 2.3570177773034975, 5.6279470675750378 ], "alpha": [ [ [ - 0.48519605663895082, 0.072476485656173609, 0.2016497322433165, - 0.24067772546155902 + 0.12208382701175115, 0.096277220888585938, 0.37179118427596042, + 0.40984776782370236 ], [ - 0.29194266525441742, 0.18015647786812236, 0.182471506206822, - 0.34542935067063812 + 0.41367090546791735, 0.4243728017966204, 0.074953157708614424, + 0.087003135026847922 ], [ - 0.0088315040897732552, 0.17056928511037653, 0.48201658293201283, - 0.33858262786783733 + 0.37512017100653144, 0.22485649523657397, 0.1430831011691662, + 0.2569402325877283 ] ], [ [ - 0.054263547813664666, 0.68720978364541985, 0.0898770345967127, - 0.16864963394420274 + 0.2724045704160345, 0.36283264933917447, 0.19251820791261157, + 0.17224457233217949 ], [ - 0.2342347013480465, 0.52898483516380979, 0.22007908017060804, - 0.016701383317535796 + 0.10266646872740415, 0.11717509867472625, 0.5250961979218981, + 0.2550622346759715 ], [ - 0.46151613790635287, 0.21628949924286106, 0.083506221199266831, - 0.23868814165151919 + 0.052691526059705653, 0.39621004190419407, 0.1746323105072301, + 0.3764661215288701 ] ] ], "cm": [ [ - [ 5.8742928607162987, 3.37481110970312 ], - [ 3.8109641809714816, 3.479741575288458 ], - [ 4.9095671336661875, 2.9562261117925726 ], - [ 2.4883741076653534, 2.6270517434309695 ] + [ 2.2161729152433294, 4.2735083903241762 ], + [ 2.3228288774921975, 4.2551781219057006 ], + [ 4.0817896796501536, 4.4410534765422955 ], + [ 5.777652580986282, 2.702994187045916 ] ], [ - [ 4.2306529269790234, 2.8089199427657352 ], - [ 2.1435389197334103, 6.3227988268983939 ], - [ 4.4697237280322586, 3.9097197582393859 ], - [ 2.4234778406541011, 6.5293446112399147 ] + [ 4.4021064947227728, 6.2791571000818873 ], + [ 2.3751795650893412, 3.2065345846875557 ], + [ 3.333808860667193, 2.6610201337431665 ], + [ 2.6936180212748866, 5.8592677493484269 ] ], [ - [ 6.3744994650304063, 3.3658103239740145 ], - [ 2.6241457673137845, 5.8378916448461453 ], - [ 3.0834646138512292, 5.086041093051799 ], - [ 3.3537337347439853, 2.0610400332102468 ] + [ 4.1249059398888281, 4.6356500684062221 ], + [ 4.1807348015868211, 6.38945134357675 ], + [ 5.6716804030142018, 6.9022942821354389 ], + [ 3.1672905280367103, 4.8241949183238049 ] ] ], "L_Omega": [ - [ 1.6859019351132143, 0.0, 0.0 ], - [ 1.5100530757160859, 1.9540945519975061, 0.0 ], - [ 0.80305935498225411, 0.41755234613412578, 0.34527779020597843 ] + [ 1.9177988996943971, 0.0, 0.0 ], + [ 1.4456130219514518, 1.0865133122191057, 0.0 ], + [ 1.6437414070749543, 0.50344658420634625, 1.7675118195549122 ] ], "L_Corr": [ [ 1.0, 0.0, 0.0 ], - [ 0.7252883286260241, 0.68844523410280678, 0.0 ], - [ 0.59817858939472945, 0.19697185555697858, 0.77677825877670681 ] + [ 0.84721537574654238, 0.5312495713867873, 0.0 ], + [ 0.60748819602989068, 0.60437187989848917, 0.51545389946368159 ] ], "Omega": [ [ [ [ - [ 1.0, 0.55001344718079914, 0.66782858658963418 ], - [ 0.55001344718079914, 1.0, 0.98624028384536921 ], - [ 0.66782858658963418, 0.98624028384536921, 0.99999999999999978 ] + [ 1.0, 0.68939029205051539, 0.504611238921639 ], + [ 0.68939029205051539, 1.0000000000000002, 0.78401979256357612 ], + [ 0.504611238921639, 0.78401979256357612, 1.0 ] ], [ - [ 1.0, 0.9318909439172286, 0.12646004908579564 ], - [ 0.9318909439172286, 1.0, 0.38310995933766717 ], - [ 0.12646004908579564, 0.38310995933766717, 1.0000000000000004 ] + [ 1.0, 0.74541410105288353, 0.30190927376354765 ], + [ 0.74541410105288353, 1.0, 0.44988416836084927 ], + [ 0.30190927376354765, 0.44988416836084927, 0.99999999999999978 ] ] ], [ [ - [ 1.0, 0.49318905583978717, 0.35752652655917544 ], - [ 0.49318905583978717, 0.99999999999999989, 0.71232427460865166 ], - [ 0.35752652655917544, 0.71232427460865166, 1.0 ] + [ 1.0, 0.55326714683507616, 0.12986029237307539 ], + [ 0.55326714683507616, 1.0, 0.16728991788128289 ], + [ 0.12986029237307539, 0.16728991788128289, 1.0000000000000002 ] ], [ - [ 1.0, 0.79324509354994, 0.725578409364175 ], - [ 0.79324509354994, 1.0000000000000002, 0.95577977609788223 ], - [ 0.725578409364175, 0.95577977609788223, 0.99999999999999989 ] + [ 1.0, 0.721237684869733, 0.846045734857208 ], + [ 0.721237684869733, 1.0, 0.76818273298484252 ], + [ 0.846045734857208, 0.76818273298484252, 0.99999999999999989 ] ] ], [ [ - [ 1.0, 0.36584941582821728, 0.13479496359958917 ], - [ 0.36584941582821728, 1.0, 0.56339796833654132 ], - [ 0.13479496359958917, 0.56339796833654132, 1.0000000000000002 ] + [ 1.0, 0.77072046015371076, 0.19772724769821418 ], + [ 0.77072046015371076, 1.0, 0.77299666827811075 ], + [ 0.19772724769821418, 0.77299666827811075, 1.0 ] ], [ - [ 1.0, 0.50976927614597622, 0.80207673991446571 ], - [ 0.50976927614597622, 1.0, 0.68850962645333658 ], - [ 0.80207673991446571, 0.68850962645333658, 1.0 ] + [ 1.0, 0.77442690127950775, 0.34050319470228807 ], + [ 0.77442690127950775, 1.0, 0.33517315284975119 ], + [ 0.34050319470228807, 0.33517315284975119, 0.99999999999999989 ] ] ] ], [ [ [ - [ 1.0, 0.6761981720506417, 0.14048539967195806 ], - [ 0.6761981720506417, 1.0000000000000002, 0.59606763590879885 ], - [ 0.14048539967195806, 0.59606763590879885, 0.99999999999999978 ] + [ 1.0, 0.824858543378145, 0.48649005735757422 ], + [ 0.824858543378145, 1.0000000000000002, 0.68242926355025413 ], + [ 0.48649005735757422, 0.68242926355025413, 0.99999999999999989 ] ], [ - [ 1.0, 0.83267548540315484, 0.1161421868513267 ], - [ 0.83267548540315484, 1.0, 0.36127622503604984 ], - [ 0.1161421868513267, 0.36127622503604984, 1.0 ] + [ 1.0, 0.70338407009562276, 0.57632477902590085 ], + [ 0.70338407009562276, 1.0000000000000002, 0.494867050063337 ], + [ 0.57632477902590085, 0.494867050063337, 0.99999999999999989 ] ] ], [ [ - [ 1.0, 0.65094812690160486, 0.24833571106660485 ], - [ 0.65094812690160486, 0.99999999999999978, 0.87861205469182324 ], - [ 0.24833571106660485, 0.87861205469182324, 0.99999999999999989 ] + [ 1.0, 0.57297498396235214, 0.41285332825575527 ], + [ 0.57297498396235214, 0.99999999999999989, 0.58082681602684771 ], + [ 0.41285332825575527, 0.58082681602684771, 0.99999999999999989 ] ], [ - [ 1.0, 0.87121505620140771, 0.096546271520155144 ], - [ 0.87121505620140771, 1.0, 0.40952490643483219 ], - [ 0.096546271520155144, 0.40952490643483219, 1.0 ] + [ 1.0, 0.96800009996808856, 0.77013498444134976 ], + [ 0.96800009996808856, 1.0, 0.75291247044061227 ], + [ 0.77013498444134976, 0.75291247044061227, 0.99999999999999978 ] ] ], [ [ - [ 1.0, 0.84363172177952406, 0.45884050416569433 ], - [ 0.84363172177952406, 1.0, 0.81636519430225407 ], - [ 0.45884050416569433, 0.81636519430225407, 0.99999999999999978 ] + [ 1.0, 0.31657617071814048, 0.036375889145384352 ], + [ 0.31657617071814048, 1.0, 0.89147176254678817 ], + [ 0.036375889145384352, 0.89147176254678817, 1.0000000000000002 ] ], [ - [ 1.0, 0.99893086230667771, 0.51520325383243815 ], - [ 0.99893086230667771, 1.0000000000000002, 0.54737233709009459 ], - [ 0.51520325383243815, 0.54737233709009459, 1.0000000000000002 ] + [ 1.0, 0.999993509817939, 0.25954691874754277 ], + [ 0.999993509817939, 0.99999999999999989, 0.262726899218798 ], + [ 0.25954691874754277, 0.262726899218798, 0.99999999999999978 ] ] ] ] @@ -144,45 +144,45 @@ [ [ [ - [ 5.1663094098648292, 2.6992998584233279 ], - [ 3.516087101086983, 3.8723933583731247 ], - [ 3.0316488865424223, 3.4729341839760846 ], - [ 5.5066971216820511, 3.3859493386390422 ] + [ 3.5433970587540458, 6.2950770670120129 ], + [ 5.6201329242297717, 2.3146297090598931 ], + [ 5.1462314404870781, 3.2721399347170417 ], + [ 4.0656750934461474, 3.2411893901314164 ] ], [ - [ 6.8658257903710336, 4.3600362419403371 ], - [ 3.9211239216617804, 2.8351155303955791 ], - [ 4.1195086666464125, 3.8023547393683415 ], - [ 6.2580026511104849, 2.1005711764538493 ] + [ 6.6251907292217185, 6.9002416987352486 ], + [ 5.0112130349146735, 3.6900697736636112 ], + [ 2.2371032314273811, 5.2247184067767574 ], + [ 2.17068504261216, 2.2456130023389376 ] ], [ - [ 4.7268899914650859, 4.5923050316777463 ], - [ 3.4762358667133579, 4.2442096457788274 ], - [ 3.5455589406474659, 5.4652971907441561 ], - [ 6.3153655925600107, 3.3120627802259204 ] + [ 5.80176022175848, 5.91129971601706 ], + [ 4.3864026258516891, 3.798097280619011 ], + [ 2.7099329809140325, 5.8044690521155928 ], + [ 4.4681514496492873, 4.0817140157307108 ] ] ], [ [ - [ 3.9162392626483817, 3.749689288247267 ], - [ 5.0089472710020306, 5.5700410415594721 ], - [ 5.3697354862617646, 4.9141010702010668 ], - [ 4.7877738836159667, 4.6208347757670065 ] + [ 4.8785830147136737, 4.3952007395396215 ], + [ 3.6619384086120315, 2.8332738265524342 ], + [ 6.2210444243377845, 2.2427497248475379 ], + [ 3.7985260179585287, 2.8924615842522448 ] ], [ - [ 2.6180703830370833, 5.8691356845339673 ], - [ 6.33506959006537, 6.52981582099901 ], - [ 4.502995465196248, 4.377433887799496 ], - [ 2.7662057362027213, 6.2876357933511784 ] + [ 6.9248095019047611, 2.0818481647367837 ], + [ 2.1600854255133495, 4.6074404412211223 ], + [ 4.0638348657136438, 4.2456437508499167 ], + [ 6.8999181654782307, 2.6275592155970942 ] ], [ - [ 5.3389919152354084, 4.4303064363263145 ], - [ 2.41243332010433, 2.826547127396049 ], - [ 4.6649955354162387, 6.208905770686683 ], - [ 4.84395213093654, 5.9671366331372635 ] + [ 3.9680861372977834, 5.7360991610562344 ], + [ 4.18537248038814, 4.2370654246554018 ], + [ 6.9967852481487673, 3.2049914753910862 ], + [ 5.1202970072734892, 3.2138559765944259 ] ] ] ] ], - "p": 0.85714999002788439 + "p": 0.44452773799726492 } diff --git a/src/test/test-models/good/model/parameters.stan b/src/test/test-models/good/model/parameters.stan index bdd9cfacf19..104a456d3dd 100644 --- a/src/test/test-models/good/model/parameters.stan +++ b/src/test/test-models/good/model/parameters.stan @@ -1,7 +1,7 @@ // used to test parameter serialization/deserialization parameters { real theta; - array[2] real sigma; + array[3] real sigma; vector[3] mu; array[2, 3] simplex[4] alpha; complex_matrix[3, 4] cm; @@ -11,3 +11,11 @@ parameters { array[1, 2, 3] complex_vector[4] cv; real p; } +transformed parameters { + vector[3] mu2; + mu2 = mu + 1; +} +generated quantities { + array[3] real y; + y = normal_rng(mu2, sigma); +} diff --git a/src/test/unit/model/array_functions_roundtrip_test.cpp b/src/test/unit/model/array_functions_roundtrip_test.cpp index d572324ba9c..1164aabe22d 100644 --- a/src/test/unit/model/array_functions_roundtrip_test.cpp +++ b/src/test/unit/model/array_functions_roundtrip_test.cpp @@ -5,46 +5,61 @@ #include #include -TEST(ModelUtil, write_array_unconstrain_array_roundtrip) { - stan::io::empty_var_context data_var_context; - stan_model model(data_var_context, 0, static_cast(0)); +class ModelArrayFunctionsRoundtripTest : public testing::Test { + public: + ModelArrayFunctionsRoundtripTest() + : model(context, 0, nullptr), rng(12324232), inits(nullptr) { + out.str(""); - std::vector json_path; - json_path = {"src", "test", "test-models", - "good", "model", "parameters.inits.json"}; - std::string filename = paths_to_fname(json_path); - std::ifstream in(filename); - stan::json::json_data inits(in); + std::vector json_path = { + "src", "test", "test-models", "good", "model", "parameters.inits.json"}; + std::string filename = paths_to_fname(json_path); + std::ifstream in(filename); + inits = new stan::json::json_data(in); + } - std::stringstream out; - out.str(""); + ~ModelArrayFunctionsRoundtripTest() { delete inits; } - // unused in this model but needed for write_array - auto rng = stan::services::util::create_rng(12324232, 1); + stan::io::empty_var_context context; + stan::io::var_context* inits; + stan_model model; + std::stringstream out; + boost::ecuyer1988 rng; - try { + /** + * Test that the unconstrain_array function is the inverse of the + * write_array function. This tests the Eigen overloads. + * + * This calls transform_inits, write_array, and then unconstrain_array + * and asserts that the output of unconstrain_array is the same as the + * output of transform_inits. + */ + void eigen_round_trip(bool include_gq, bool include_tp) { Eigen::VectorXd init_vector; - model.transform_inits(inits, init_vector, &out); + model.transform_inits(*inits, init_vector, &out); Eigen::VectorXd written_vector; - model.write_array(rng, init_vector, written_vector, &out); + model.write_array(rng, init_vector, written_vector, include_gq, include_tp, + &out); Eigen::VectorXd recovered_vector; model.unconstrain_array(written_vector, recovered_vector, &out); EXPECT_MATRIX_NEAR(init_vector, recovered_vector, 1e-10); EXPECT_EQ("", out.str()); - } catch (...) { - FAIL() << "write_array_unconstrain_array_roundtrip Eigen::VectorXd"; } - try { + /** + * Same as eigen_round_trip but for the std::vector overloads + */ + void std_vec_round_trip(bool include_gq, bool include_tp) { std::vector unused; std::vector init_vector; - model.transform_inits(inits, unused, init_vector, &out); + model.transform_inits(*inits, unused, init_vector, &out); std::vector written_vector; - model.write_array(rng, init_vector, unused, written_vector, &out); + model.write_array(rng, init_vector, unused, written_vector, include_gq, + include_tp, &out); std::vector recovered_vector; model.unconstrain_array(written_vector, recovered_vector, &out); @@ -54,7 +69,23 @@ TEST(ModelUtil, write_array_unconstrain_array_roundtrip) { } EXPECT_EQ("", out.str()); - } catch (...) { - FAIL() << "write_array_unconstrain_array_roundtrip std::vector"; } +}; + +// test all combinations of include_gq and include_tp. +// unconstrain_array should ignore them as they appear at the end +// of the written vectors + +TEST_F(ModelArrayFunctionsRoundtripTest, eigen_overloads) { + eigen_round_trip(false, false); + eigen_round_trip(false, true); + eigen_round_trip(true, false); + eigen_round_trip(true, true); +} + +TEST_F(ModelArrayFunctionsRoundtripTest, std_vector_overloads) { + std_vec_round_trip(false, false); + std_vec_round_trip(false, true); + std_vec_round_trip(true, false); + std_vec_round_trip(true, true); } From 3f596d0daae7d2dd1de84d91bb53db0321824bd5 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 11 Apr 2023 15:29:47 -0400 Subject: [PATCH 5/9] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/model/model_base.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/stan/model/model_base.hpp b/src/stan/model/model_base.hpp index f955dfeec35..2aaaeb3a0e8 100644 --- a/src/stan/model/model_base.hpp +++ b/src/stan/model/model_base.hpp @@ -368,7 +368,6 @@ class model_base : public prob_grad { bool include_tparams = true, bool include_gqs = true, std::ostream* msgs = 0) const = 0; - /** * Convert the specified sequence of constrained parameters to a * sequence of unconstrained parameters. @@ -632,9 +631,9 @@ class model_base : public prob_grad { * @param[in,out] params_r unconstrained parameters produced * @param[in,out] msgs msgs stream to which messages are written */ - virtual void unconstrain_array(const std::vector& params_constrained_r, - std::vector& params_r, - std::ostream* msgs = nullptr) const = 0; + virtual void unconstrain_array( + const std::vector& params_constrained_r, + std::vector& params_r, std::ostream* msgs = nullptr) const = 0; #ifdef STAN_MODEL_FVAR_VAR From 02630f1e053471d06c40f8a3c0b92bbcd3c90f66 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 4 May 2023 10:21:14 -0400 Subject: [PATCH 6/9] Add to model_base_crtp --- src/stan/model/model_base_crtp.hpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/stan/model/model_base_crtp.hpp b/src/stan/model/model_base_crtp.hpp index 96f85a5b3e0..244a5fa5849 100644 --- a/src/stan/model/model_base_crtp.hpp +++ b/src/stan/model/model_base_crtp.hpp @@ -141,6 +141,13 @@ class model_base_crtp : public stan::model::model_base { rng, theta, vars, include_tparams, include_gqs, msgs); } + void unconstrain_array(const Eigen::VectorXd& params_constrained_r, + Eigen::VectorXd& params_r, + std::ostream* msgs = nullptr) const override { + return static_cast(this)->unconstrain_array(params_constrained_r, + params_r, msgs); + } + // TODO(carpenter): remove redundant std::vector methods below here ===== // ====================================================================== @@ -203,6 +210,20 @@ class model_base_crtp : public stan::model::model_base { rng, theta, theta_i, vars, include_tparams, include_gqs, msgs); } + void void unconstrain_array(const std::vector& params_constrained_r, + std::vector& params_r, + std::ostream* msgs = nullptr) const override { + return static_cast(this)->unconstrain_array(params_constrained_r, + params_r, msgs); + } + + void transform_inits(const io::var_context& context, + Eigen::VectorXd& params_r, + std::ostream* msgs) const override { + return static_cast(this)->transform_inits(context, params_r, + msgs); + } + #ifdef STAN_MODEL_FVAR_VAR /** From e48eac567930a3f50836b1f55c3173634735df32 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 4 May 2023 10:21:36 -0400 Subject: [PATCH 7/9] Typo fix --- src/stan/model/model_base_crtp.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stan/model/model_base_crtp.hpp b/src/stan/model/model_base_crtp.hpp index 244a5fa5849..f866467b99a 100644 --- a/src/stan/model/model_base_crtp.hpp +++ b/src/stan/model/model_base_crtp.hpp @@ -210,7 +210,7 @@ class model_base_crtp : public stan::model::model_base { rng, theta, theta_i, vars, include_tparams, include_gqs, msgs); } - void void unconstrain_array(const std::vector& params_constrained_r, + void unconstrain_array(const std::vector& params_constrained_r, std::vector& params_r, std::ostream* msgs = nullptr) const override { return static_cast(this)->unconstrain_array(params_constrained_r, From aeb9eae6b05c03e45a082fdf5acae600c384044f Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 4 May 2023 10:22:23 -0400 Subject: [PATCH 8/9] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/model/model_base_crtp.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stan/model/model_base_crtp.hpp b/src/stan/model/model_base_crtp.hpp index f866467b99a..90a4d9ca359 100644 --- a/src/stan/model/model_base_crtp.hpp +++ b/src/stan/model/model_base_crtp.hpp @@ -211,8 +211,8 @@ class model_base_crtp : public stan::model::model_base { } void unconstrain_array(const std::vector& params_constrained_r, - std::vector& params_r, - std::ostream* msgs = nullptr) const override { + std::vector& params_r, + std::ostream* msgs = nullptr) const override { return static_cast(this)->unconstrain_array(params_constrained_r, params_r, msgs); } From 02583ce4f140d218be25f3da63460d674f907a6b Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 4 May 2023 10:47:52 -0400 Subject: [PATCH 9/9] Clean up roundtrip test --- .../model/array_functions_roundtrip_test.cpp | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/test/unit/model/array_functions_roundtrip_test.cpp b/src/test/unit/model/array_functions_roundtrip_test.cpp index 1164aabe22d..59d43892dac 100644 --- a/src/test/unit/model/array_functions_roundtrip_test.cpp +++ b/src/test/unit/model/array_functions_roundtrip_test.cpp @@ -5,23 +5,21 @@ #include #include +auto get_init_json() { + std::vector json_path = { + "src", "test", "test-models", "good", "model", "parameters.inits.json"}; + std::string filename = paths_to_fname(json_path); + std::ifstream in(filename); + return std::unique_ptr(new stan::json::json_data(in)); +} + class ModelArrayFunctionsRoundtripTest : public testing::Test { public: ModelArrayFunctionsRoundtripTest() - : model(context, 0, nullptr), rng(12324232), inits(nullptr) { - out.str(""); - - std::vector json_path = { - "src", "test", "test-models", "good", "model", "parameters.inits.json"}; - std::string filename = paths_to_fname(json_path); - std::ifstream in(filename); - inits = new stan::json::json_data(in); - } - - ~ModelArrayFunctionsRoundtripTest() { delete inits; } + : model(context, 0, nullptr), rng(12324232), inits(get_init_json()) {} stan::io::empty_var_context context; - stan::io::var_context* inits; + std::unique_ptr inits; stan_model model; std::stringstream out; boost::ecuyer1988 rng;