-
-
Notifications
You must be signed in to change notification settings - Fork 368
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3179 from stan-dev/model-base-unconstrain-array
Expose new base class method 'unconstrain_array'
- Loading branch information
Showing
10 changed files
with
386 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
{ | ||
"theta": 3.35045497556837, | ||
"sigma": [ 1.0678931954592095, 0.14205376411270393, 2.7743965352097821 ], | ||
"mu": [ 2.1285977303026997, 2.3570177773034975, 5.6279470675750378 ], | ||
"alpha": [ | ||
[ | ||
[ | ||
0.12208382701175115, 0.096277220888585938, 0.37179118427596042, | ||
0.40984776782370236 | ||
], | ||
[ | ||
0.41367090546791735, 0.4243728017966204, 0.074953157708614424, | ||
0.087003135026847922 | ||
], | ||
[ | ||
0.37512017100653144, 0.22485649523657397, 0.1430831011691662, | ||
0.2569402325877283 | ||
] | ||
], | ||
[ | ||
[ | ||
0.2724045704160345, 0.36283264933917447, 0.19251820791261157, | ||
0.17224457233217949 | ||
], | ||
[ | ||
0.10266646872740415, 0.11717509867472625, 0.5250961979218981, | ||
0.2550622346759715 | ||
], | ||
[ | ||
0.052691526059705653, 0.39621004190419407, 0.1746323105072301, | ||
0.3764661215288701 | ||
] | ||
] | ||
], | ||
"cm": [ | ||
[ | ||
[ 2.2161729152433294, 4.2735083903241762 ], | ||
[ 2.3228288774921975, 4.2551781219057006 ], | ||
[ 4.0817896796501536, 4.4410534765422955 ], | ||
[ 5.777652580986282, 2.702994187045916 ] | ||
], | ||
[ | ||
[ 4.4021064947227728, 6.2791571000818873 ], | ||
[ 2.3751795650893412, 3.2065345846875557 ], | ||
[ 3.333808860667193, 2.6610201337431665 ], | ||
[ 2.6936180212748866, 5.8592677493484269 ] | ||
], | ||
[ | ||
[ 4.1249059398888281, 4.6356500684062221 ], | ||
[ 4.1807348015868211, 6.38945134357675 ], | ||
[ 5.6716804030142018, 6.9022942821354389 ], | ||
[ 3.1672905280367103, 4.8241949183238049 ] | ||
] | ||
], | ||
"L_Omega": [ | ||
[ 1.9177988996943971, 0.0, 0.0 ], | ||
[ 1.4456130219514518, 1.0865133122191057, 0.0 ], | ||
[ 1.6437414070749543, 0.50344658420634625, 1.7675118195549122 ] | ||
], | ||
"L_Corr": [ | ||
[ 1.0, 0.0, 0.0 ], | ||
[ 0.84721537574654238, 0.5312495713867873, 0.0 ], | ||
[ 0.60748819602989068, 0.60437187989848917, 0.51545389946368159 ] | ||
], | ||
"Omega": [ | ||
[ | ||
[ | ||
[ | ||
[ 1.0, 0.68939029205051539, 0.504611238921639 ], | ||
[ 0.68939029205051539, 1.0000000000000002, 0.78401979256357612 ], | ||
[ 0.504611238921639, 0.78401979256357612, 1.0 ] | ||
], | ||
[ | ||
[ 1.0, 0.74541410105288353, 0.30190927376354765 ], | ||
[ 0.74541410105288353, 1.0, 0.44988416836084927 ], | ||
[ 0.30190927376354765, 0.44988416836084927, 0.99999999999999978 ] | ||
] | ||
], | ||
[ | ||
[ | ||
[ 1.0, 0.55326714683507616, 0.12986029237307539 ], | ||
[ 0.55326714683507616, 1.0, 0.16728991788128289 ], | ||
[ 0.12986029237307539, 0.16728991788128289, 1.0000000000000002 ] | ||
], | ||
[ | ||
[ 1.0, 0.721237684869733, 0.846045734857208 ], | ||
[ 0.721237684869733, 1.0, 0.76818273298484252 ], | ||
[ 0.846045734857208, 0.76818273298484252, 0.99999999999999989 ] | ||
] | ||
], | ||
[ | ||
[ | ||
[ 1.0, 0.77072046015371076, 0.19772724769821418 ], | ||
[ 0.77072046015371076, 1.0, 0.77299666827811075 ], | ||
[ 0.19772724769821418, 0.77299666827811075, 1.0 ] | ||
], | ||
[ | ||
[ 1.0, 0.77442690127950775, 0.34050319470228807 ], | ||
[ 0.77442690127950775, 1.0, 0.33517315284975119 ], | ||
[ 0.34050319470228807, 0.33517315284975119, 0.99999999999999989 ] | ||
] | ||
] | ||
], | ||
[ | ||
[ | ||
[ | ||
[ 1.0, 0.824858543378145, 0.48649005735757422 ], | ||
[ 0.824858543378145, 1.0000000000000002, 0.68242926355025413 ], | ||
[ 0.48649005735757422, 0.68242926355025413, 0.99999999999999989 ] | ||
], | ||
[ | ||
[ 1.0, 0.70338407009562276, 0.57632477902590085 ], | ||
[ 0.70338407009562276, 1.0000000000000002, 0.494867050063337 ], | ||
[ 0.57632477902590085, 0.494867050063337, 0.99999999999999989 ] | ||
] | ||
], | ||
[ | ||
[ | ||
[ 1.0, 0.57297498396235214, 0.41285332825575527 ], | ||
[ 0.57297498396235214, 0.99999999999999989, 0.58082681602684771 ], | ||
[ 0.41285332825575527, 0.58082681602684771, 0.99999999999999989 ] | ||
], | ||
[ | ||
[ 1.0, 0.96800009996808856, 0.77013498444134976 ], | ||
[ 0.96800009996808856, 1.0, 0.75291247044061227 ], | ||
[ 0.77013498444134976, 0.75291247044061227, 0.99999999999999978 ] | ||
] | ||
], | ||
[ | ||
[ | ||
[ 1.0, 0.31657617071814048, 0.036375889145384352 ], | ||
[ 0.31657617071814048, 1.0, 0.89147176254678817 ], | ||
[ 0.036375889145384352, 0.89147176254678817, 1.0000000000000002 ] | ||
], | ||
[ | ||
[ 1.0, 0.999993509817939, 0.25954691874754277 ], | ||
[ 0.999993509817939, 0.99999999999999989, 0.262726899218798 ], | ||
[ 0.25954691874754277, 0.262726899218798, 0.99999999999999978 ] | ||
] | ||
] | ||
] | ||
], | ||
"cv": [ | ||
[ | ||
[ | ||
[ | ||
[ 3.5433970587540458, 6.2950770670120129 ], | ||
[ 5.6201329242297717, 2.3146297090598931 ], | ||
[ 5.1462314404870781, 3.2721399347170417 ], | ||
[ 4.0656750934461474, 3.2411893901314164 ] | ||
], | ||
[ | ||
[ 6.6251907292217185, 6.9002416987352486 ], | ||
[ 5.0112130349146735, 3.6900697736636112 ], | ||
[ 2.2371032314273811, 5.2247184067767574 ], | ||
[ 2.17068504261216, 2.2456130023389376 ] | ||
], | ||
[ | ||
[ 5.80176022175848, 5.91129971601706 ], | ||
[ 4.3864026258516891, 3.798097280619011 ], | ||
[ 2.7099329809140325, 5.8044690521155928 ], | ||
[ 4.4681514496492873, 4.0817140157307108 ] | ||
] | ||
], | ||
[ | ||
[ | ||
[ 4.8785830147136737, 4.3952007395396215 ], | ||
[ 3.6619384086120315, 2.8332738265524342 ], | ||
[ 6.2210444243377845, 2.2427497248475379 ], | ||
[ 3.7985260179585287, 2.8924615842522448 ] | ||
], | ||
[ | ||
[ 6.9248095019047611, 2.0818481647367837 ], | ||
[ 2.1600854255133495, 4.6074404412211223 ], | ||
[ 4.0638348657136438, 4.2456437508499167 ], | ||
[ 6.8999181654782307, 2.6275592155970942 ] | ||
], | ||
[ | ||
[ 3.9680861372977834, 5.7360991610562344 ], | ||
[ 4.18537248038814, 4.2370654246554018 ], | ||
[ 6.9967852481487673, 3.2049914753910862 ], | ||
[ 5.1202970072734892, 3.2138559765944259 ] | ||
] | ||
] | ||
] | ||
], | ||
"p": 0.44452773799726492 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// used to test parameter serialization/deserialization | ||
parameters { | ||
real theta; | ||
array[3] real<lower=0> sigma; | ||
vector[3] mu; | ||
array[2, 3] simplex[4] alpha; | ||
complex_matrix[3, 4] cm; | ||
cholesky_factor_cov[3] L_Omega; | ||
cholesky_factor_corr[3] L_Corr; | ||
array[2, 3, 2] corr_matrix[3] Omega; | ||
array[1, 2, 3] complex_vector[4] cv; | ||
real<lower=0, upper=1> p; | ||
} | ||
transformed parameters { | ||
vector[3] mu2; | ||
mu2 = mu + 1; | ||
} | ||
generated quantities { | ||
array[3] real y; | ||
y = normal_rng(mu2, sigma); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#include <stan/model/log_prob_grad.hpp> | ||
#include <stan/io/empty_var_context.hpp> | ||
#include <stan/io/json/json_data.hpp> | ||
#include <test/test-models/good/model/parameters.hpp> | ||
#include <test/unit/util.hpp> | ||
#include <gtest/gtest.h> | ||
|
||
auto get_init_json() { | ||
std::vector<std::string> json_path = { | ||
"src", "test", "test-models", "good", "model", "parameters.inits.json"}; | ||
std::string filename = paths_to_fname(json_path); | ||
std::ifstream in(filename); | ||
return std::unique_ptr<stan::io::var_context>(new stan::json::json_data(in)); | ||
} | ||
|
||
class ModelArrayFunctionsRoundtripTest : public testing::Test { | ||
public: | ||
ModelArrayFunctionsRoundtripTest() | ||
: model(context, 0, nullptr), rng(12324232), inits(get_init_json()) {} | ||
|
||
stan::io::empty_var_context context; | ||
std::unique_ptr<stan::io::var_context> inits; | ||
stan_model model; | ||
std::stringstream out; | ||
boost::ecuyer1988 rng; | ||
|
||
/** | ||
* Test that the unconstrain_array function is the inverse of the | ||
* write_array function. This tests the Eigen overloads. | ||
* | ||
* This calls transform_inits, write_array, and then unconstrain_array | ||
* and asserts that the output of unconstrain_array is the same as the | ||
* output of transform_inits. | ||
*/ | ||
void eigen_round_trip(bool include_gq, bool include_tp) { | ||
Eigen::VectorXd init_vector; | ||
model.transform_inits(*inits, init_vector, &out); | ||
|
||
Eigen::VectorXd written_vector; | ||
model.write_array(rng, init_vector, written_vector, include_gq, include_tp, | ||
&out); | ||
|
||
Eigen::VectorXd recovered_vector; | ||
model.unconstrain_array(written_vector, recovered_vector, &out); | ||
|
||
EXPECT_MATRIX_NEAR(init_vector, recovered_vector, 1e-10); | ||
EXPECT_EQ("", out.str()); | ||
} | ||
|
||
/** | ||
* Same as eigen_round_trip but for the std::vector overloads | ||
*/ | ||
void std_vec_round_trip(bool include_gq, bool include_tp) { | ||
std::vector<int> unused; | ||
std::vector<double> init_vector; | ||
model.transform_inits(*inits, unused, init_vector, &out); | ||
|
||
std::vector<double> written_vector; | ||
model.write_array(rng, init_vector, unused, written_vector, include_gq, | ||
include_tp, &out); | ||
|
||
std::vector<double> recovered_vector; | ||
model.unconstrain_array(written_vector, recovered_vector, &out); | ||
|
||
for (int i = 0; i < init_vector.size(); i++) { | ||
EXPECT_NEAR(init_vector[i], recovered_vector[i], 1e-10); | ||
} | ||
|
||
EXPECT_EQ("", out.str()); | ||
} | ||
}; | ||
|
||
// test all combinations of include_gq and include_tp. | ||
// unconstrain_array should ignore them as they appear at the end | ||
// of the written vectors | ||
|
||
TEST_F(ModelArrayFunctionsRoundtripTest, eigen_overloads) { | ||
eigen_round_trip(false, false); | ||
eigen_round_trip(false, true); | ||
eigen_round_trip(true, false); | ||
eigen_round_trip(true, true); | ||
} | ||
|
||
TEST_F(ModelArrayFunctionsRoundtripTest, std_vector_overloads) { | ||
std_vec_round_trip(false, false); | ||
std_vec_round_trip(false, true); | ||
std_vec_round_trip(true, false); | ||
std_vec_round_trip(true, true); | ||
} |
Oops, something went wrong.