Only swallow domain_errors in various algorithms #3259

Merged Feb 13, 2024 (6 commits; diff shown from 4 commits)
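The change applies one pattern throughout: Stan Math signals recoverable, parameter-dependent failures with std::domain_error, so the algorithms now catch only that type and penalize the offending draw or step, while every other exception propagates to the service entry point, which logs it and returns error_codes::SOFTWARE instead of continuing. A minimal sketch of the idea (illustrative only; guarded_log_prob, LogProbFn, and Logger are hypothetical stand-ins, not code from this diff):

#include <limits>
#include <stdexcept>

// Sketch of the error policy this PR enforces; names are hypothetical.
template <typename LogProbFn, typename Logger>
double guarded_log_prob(LogProbFn&& log_prob, Logger& logger) {
  try {
    return log_prob();
  } catch (const std::domain_error& e) {
    // Recoverable (e.g. a parameter value outside the support): warn and
    // treat the draw as having zero probability, then keep running.
    logger.warn(e.what());
    return -std::numeric_limits<double>::infinity();
  }
  // Any other exception type is deliberately not caught here; it propagates
  // to the service layer, which logs it and returns error_codes::SOFTWARE.
}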
4 changes: 2 additions & 2 deletions src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp
@@ -52,7 +52,7 @@ class base_hamiltonian {
void update_potential(Point& z, callbacks::logger& logger) {
try {
z.V = -stan::model::log_prob_propto<true>(model_, z.q);
- } catch (const std::exception& e) {
+ } catch (const std::domain_error& e) {
this->write_error_msg_(e, logger);
z.V = std::numeric_limits<double>::infinity();
}
@@ -62,7 +62,7 @@
try {
stan::model::gradient(model_, z.q, z.V, z.g, logger);
z.V = -z.V;
- } catch (const std::exception& e) {
+ } catch (const std::domain_error& e) {
this->write_error_msg_(e, logger);
z.V = std::numeric_limits<double>::infinity();
}
4 changes: 2 additions & 2 deletions src/stan/optimization/bfgs.hpp
@@ -309,7 +309,7 @@ class ModelAdaptor {

try {
f = -log_prob_propto<jacobian>(_model, _x, _params_i, _msgs);
- } catch (const std::exception &e) {
+ } catch (const std::domain_error &e) {
if (_msgs)
(*_msgs) << e.what() << std::endl;
return 1;
@@ -341,7 +341,7 @@ class ModelAdaptor {

try {
f = -log_prob_grad<true, jacobian>(_model, _x, _params_i, _g, _msgs);
- } catch (const std::exception &e) {
+ } catch (const std::domain_error &e) {
if (_msgs)
(*_msgs) << e.what() << std::endl;
return 1;
2 changes: 1 addition & 1 deletion src/stan/optimization/newton.hpp
@@ -60,7 +60,7 @@ double newton_step(M& model, std::vector<double>& params_r,
try {
f1 = stan::model::log_prob_grad<true, jacobian>(model, new_params_r,
params_i, gradient);
- } catch (std::exception& e) {
+ } catch (std::domain_error& e) {
// FIXME: this is not a good way to handle a general exception
f1 = -1e100;
}
9 changes: 7 additions & 2 deletions src/stan/services/experimental/advi/fullrank.hpp
@@ -86,8 +86,13 @@ int fullrank(Model& model, const stan::io::var_context& init,
stan::rng_t>
cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
output_samples);
- cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
-              max_iterations, logger, parameter_writer, diagnostic_writer);
+ try {
+   cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
+                max_iterations, logger, parameter_writer, diagnostic_writer);
+ } catch (const std::exception& e) {
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }

return stan::services::error_codes::OK;
}
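With the run wrapped this way, a failure reaches the caller as a return code rather than as an exception. A hypothetical caller-side check (only error_codes::OK and error_codes::SOFTWARE come from the diff; run_advi_fullrank stands in for the full stan::services::experimental::advi::fullrank(...) call, whose arguments are omitted):

#include <functional>
#include <iostream>

// Hypothetical wrapper: maps the service return code to a caller exit status.
int run_and_report(const std::function<int()>& run_advi_fullrank) {
  const int return_code = run_advi_fullrank();
  if (return_code != 0) {  // 0 is stan::services::error_codes::OK
    // The service has already logged the exception text via logger.error();
    // the caller only decides what to do with the non-zero code.
    std::cerr << "ADVI failed with return code " << return_code << "\n";
  }
  return return_code;
}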
9 changes: 7 additions & 2 deletions src/stan/services/experimental/advi/meanfield.hpp
@@ -85,8 +85,13 @@ int meanfield(Model& model, const stan::io::var_context& init,
stan::rng_t>
cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
output_samples);
- cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
-              max_iterations, logger, parameter_writer, diagnostic_writer);
+ try {
+   cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
+                max_iterations, logger, parameter_writer, diagnostic_writer);
+ } catch (const std::exception& e) {
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }

return stan::services::error_codes::OK;
}
42 changes: 37 additions & 5 deletions src/stan/services/optimize/bfgs.hpp
@@ -96,7 +96,16 @@ int bfgs(Model& model, const stan::io::var_context& init,
if (save_iterations) {
std::vector<double> values;
std::stringstream msg;
- model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
+ try {
+   model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                     &msg);
+ } catch (const std::exception& e) {
+   if (msg.str().length() > 0) {
+     logger.info(msg);
+   }
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
if (msg.str().length() > 0)
logger.info(msg);

@@ -119,7 +128,13 @@ int bfgs(Model& model, const stan::io::var_context& init,
" # evals"
" Notes ");

- ret = bfgs.step();
+ try {
+   ret = bfgs.step();
+ } catch (const std::exception& e) {
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
Review comment (Collaborator): Can we move this to the outer part of the while loop?

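A rough sketch of the restructuring being asked for here, hoisting the try/catch around the optimization loop instead of around the single step() call (the loop condition and per-iteration body are abbreviated; per the "Moved both" reply in the lbfgs.hpp thread below, the move landed in a later commit of this PR):

// Sketch only; `ret`, `bfgs`, `lp`, `cont_vector`, `interrupt`, `logger`, and
// `error_codes` are the names used in the surrounding service code.
try {
  while (ret == 0) {
    interrupt();
    ret = bfgs.step();
    lp = bfgs.logp();
    bfgs.params_r(cont_vector);
    // ... per-iteration refresh logging and writer output ...
  }
} catch (const std::exception& e) {
  logger.error(e.what());
  return error_codes::SOFTWARE;
}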

lp = bfgs.logp();
bfgs.params_r(cont_vector);

@@ -150,8 +165,16 @@ int bfgs(Model& model, const stan::io::var_context& init,
if (save_iterations) {
std::vector<double> values;
std::stringstream msg;
- model.write_array(rng, cont_vector, disc_vector, values, true, true,
-                   &msg);
+ try {
+   model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                     &msg);
+ } catch (const std::exception& e) {
+   if (msg.str().length() > 0) {
+     logger.info(msg);
+   }
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
// This if is here to match the pre-refactor behavior
if (msg.str().length() > 0)
logger.info(msg);
@@ -164,7 +187,16 @@ int bfgs(Model& model, const stan::io::var_context& init,
if (!save_iterations) {
std::vector<double> values;
std::stringstream msg;
- model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
+ try {
+   model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                     &msg);
+ } catch (const std::exception& e) {
+   if (msg.str().length() > 0) {
+     logger.info(msg);
+   }
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
if (msg.str().length() > 0)
logger.info(msg);
values.insert(values.begin(), lp);
30 changes: 26 additions & 4 deletions src/stan/services/optimize/lbfgs.hpp
@@ -123,7 +123,12 @@ int lbfgs(Model& model, const stan::io::var_context& init,
" # evals"
" Notes ");

- ret = lbfgs.step();
+ try {
+   ret = lbfgs.step();
+ } catch (const std::exception& e) {
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
Review comment (Collaborator): Same here. Generally for any while loops I'd like to just have the try on the outside.

Reply (Member Author): Moved both.

lp = lbfgs.logp();
lbfgs.params_r(cont_vector);

@@ -154,8 +159,16 @@ int lbfgs(Model& model, const stan::io::var_context& init,
if (save_iterations) {
std::vector<double> values;
std::stringstream msg;
- model.write_array(rng, cont_vector, disc_vector, values, true, true,
-                   &msg);
+ try {
+   model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                     &msg);
+ } catch (const std::exception& e) {
+   if (msg.str().length() > 0) {
+     logger.info(msg);
+   }
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
if (msg.str().length() > 0)
logger.info(msg);

@@ -167,7 +180,16 @@ int lbfgs(Model& model, const stan::io::var_context& init,
if (!save_iterations) {
std::vector<double> values;
std::stringstream msg;
- model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
+ try {
+   model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                     &msg);
+ } catch (const std::exception& e) {
+   if (msg.str().length() > 0) {
+     logger.info(msg);
+   }
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
if (msg.str().length() > 0)
logger.info(msg);

2 changes: 1 addition & 1 deletion src/stan/services/optimize/newton.hpp
@@ -62,7 +62,7 @@ int newton(Model& model, const stan::io::var_context& init,
lp = model.template log_prob<false, jacobian>(cont_vector, disc_vector,
&message);
logger.info(message);
- } catch (const std::exception& e) {
+ } catch (const std::domain_error& e) {
logger.info("");
logger.info(
"Informational Message: The current"
49 changes: 27 additions & 22 deletions src/stan/services/pathfinder/multi.hpp
@@ -117,28 +117,33 @@ inline int pathfinder_lbfgs_multi(
individual_samples;
individual_samples.resize(num_paths);
std::atomic<size_t> lp_calls{0};
- tbb::parallel_for(
-     tbb::blocked_range<int>(0, num_paths), [&](tbb::blocked_range<int> r) {
-       for (int iter = r.begin(); iter < r.end(); ++iter) {
-         auto pathfinder_ret
-             = stan::services::pathfinder::pathfinder_lbfgs_single<true>(
-                 model, *(init[iter]), random_seed, stride_id + iter,
-                 init_radius, history_size, init_alpha, tol_obj, tol_rel_obj,
-                 tol_grad, tol_rel_grad, tol_param, num_iterations,
-                 num_elbo_draws, num_draws, save_iterations, refresh,
-                 interrupt, logger, init_writers[iter],
-                 single_path_parameter_writer[iter],
-                 single_path_diagnostic_writer[iter], calculate_lp);
-         if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) {
-           logger.error(std::string("Pathfinder iteration: ")
-                        + std::to_string(iter) + " failed.");
-           return;
-         }
-         individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret));
-         individual_samples[iter] = std::move(std::get<2>(pathfinder_ret));
-         lp_calls += std::get<3>(pathfinder_ret);
-       }
-     });
+ try {
+   tbb::parallel_for(
+       tbb::blocked_range<int>(0, num_paths), [&](tbb::blocked_range<int> r) {
+         for (int iter = r.begin(); iter < r.end(); ++iter) {
+           auto pathfinder_ret
+               = stan::services::pathfinder::pathfinder_lbfgs_single<true>(
+                   model, *(init[iter]), random_seed, stride_id + iter,
+                   init_radius, history_size, init_alpha, tol_obj, tol_rel_obj,
+                   tol_grad, tol_rel_grad, tol_param, num_iterations,
+                   num_elbo_draws, num_draws, save_iterations, refresh,
+                   interrupt, logger, init_writers[iter],
+                   single_path_parameter_writer[iter],
+                   single_path_diagnostic_writer[iter], calculate_lp);
+           if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) {
+             logger.error(std::string("Pathfinder iteration: ")
+                          + std::to_string(iter) + " failed.");
+             return;
+           }
+           individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret));
+           individual_samples[iter] = std::move(std::get<2>(pathfinder_ret));
+           lp_calls += std::get<3>(pathfinder_ret);
+         }
+       });
+ } catch (const std::exception& e) {
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }

// if any pathfinders failed, we want to remove their empty results
individual_lp_ratios.erase(
@@ -231,7 +236,7 @@ inline int pathfinder_lbfgs_multi(
parameter_writer(total_time_str);
}
parameter_writer();
- return 0;
+ return error_codes::OK;
}
} // namespace pathfinder
} // namespace services
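The single try/catch around tbb::parallel_for works because TBB propagates an exception thrown inside a task back to the thread that called parallel_for, cancelling the remaining tasks, so one handler at this level observes a failure from any path. A small standalone illustration of that behavior (not Stan code; the error value 1 merely stands in for error_codes::SOFTWARE):

#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>

#include <iostream>
#include <stdexcept>
#include <string>

int main() {
  try {
    tbb::parallel_for(tbb::blocked_range<int>(0, 8),
                      [](const tbb::blocked_range<int>& r) {
                        for (int i = r.begin(); i != r.end(); ++i) {
                          if (i == 3) {
                            // Thrown inside a worker task, rethrown on the
                            // thread that called parallel_for.
                            throw std::runtime_error(
                                "task " + std::to_string(i) + " failed");
                          }
                        }
                      });
  } catch (const std::exception& e) {
    // One handler sees the failure, whichever task it came from.
    std::cerr << "caught: " << e.what() << "\n";
    return 1;
  }
  return 0;
}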
17 changes: 13 additions & 4 deletions src/stan/services/pathfinder/single.hpp
@@ -250,7 +250,7 @@ inline elbo_est_t est_approx_draws(LPF&& lp_fun, ConstrainF&& constrain_fun,
approx_samples_col = approx_samples.col(i);
++lp_fun_calls;
lp_mat.coeffRef(i, 1) = lp_fun(approx_samples_col, pathfinder_ss);
- } catch (const std::exception& e) {
+ } catch (const std::domain_error& e) {
lp_mat.coeffRef(i, 1) = -std::numeric_limits<double>::infinity();
}
log_stream(logger, pathfinder_ss, iter_msg);
@@ -530,7 +530,7 @@ auto pathfinder_impl(RNG&& rng, LPFun&& lp_fun, ConstrainFun&& constrain_fun,
lp_fun, constrain_fun, rng, taylor_appx,
num_elbo_draws, alpha, iter_msg, logger),
taylor_appx);
- } catch (const std::exception& e) {
+ } catch (const std::domain_error& e) {
logger.warn(iter_msg + "ELBO estimation failed "
+ " with error: " + e.what());
return std::make_pair(internal::elbo_est_t{}, internal::taylor_approx_t{});
@@ -820,7 +820,16 @@ inline auto pathfinder_lbfgs_single(
logger.info(lbfgs_ss);
lbfgs_ss.str("");
}
- throw e;
+ if (ReturnLpSamples) {
+   // we want to terminate multi-path pathfinder during these unrecoverable
+   // exceptions
+   throw;
Comment on lines +824 to +826

Review comment (Collaborator): Why? One pathfinder can fail while the others succeed?

Reply (Member Author): If one path does something unrecoverable like indexes into a vector out of bounds, I think the right thing to do is stop and say "something is clearly wrong with your model".

Besides that, a primary motivation behind this PR is also to implement an exit() function into the language, which requires the ability to terminate the entire algorithm it is in. If one chain or one path calls exit, the entire thing still needs to halt.

Reply (Collaborator): Are only unrecoverable errors able to be thrown here? Like if the call to log_prob throws, it could be from something recoverable like one of the ODE solvers failing for a given set of parameters.

Reply (Member Author, @WardBrian, Feb 13, 2024): pathfinder_impl catches std::domain_errors, which is what the ODE solvers throw (and is generally what the math library considers 'recoverable'), so the only exceptions which reach this will be unrecoverable ones from the model or issues in our algorithm.

+ } else {
+   logger.error(e.what());
+   return internal::ret_pathfinder<ReturnLpSamples>(
+       error_codes::SOFTWARE, Eigen::Array<double, Eigen::Dynamic, 1>(0),
+       Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>(0, 0), 0);
+ }
}
}
if (unlikely(save_iterations)) {
@@ -900,7 +909,7 @@ inline auto pathfinder_lbfgs_single(
approx_samples_constrained_col)
.matrix();
}
- } catch (const std::exception& e) {
+ } catch (const std::domain_error& e) {
std::string err_msg = e.what();
logger.warn(path_num + "Final sampling approximation failed with error: "
+ err_msg);
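The thread above describes a two-tier policy: std::domain_error, which Stan Math and the ODE solvers throw for recoverable, parameter-dependent failures, is caught near the call site and the draw is merely penalized, while anything else escapes to the top-level handler, which rethrows under multi-path pathfinder (so the whole run halts) or logs and returns error_codes::SOFTWARE for a single path. A compressed sketch of that shape with hypothetical names (run_path, evaluate_draw, and the numeric codes are stand-ins, not the real pathfinder signatures):

#include <limits>
#include <stdexcept>
#include <string>

template <bool ReturnLpSamples, typename F, typename Logger>
int run_path(F&& evaluate_draw, Logger& logger) {
  try {
    double lp;
    try {
      lp = evaluate_draw();
    } catch (const std::domain_error& e) {
      // Recoverable (e.g. an ODE solver failing for one set of parameters):
      // penalize this draw and keep going.
      logger.warn(e.what());
      lp = -std::numeric_limits<double>::infinity();
    }
    logger.info("draw log density: " + std::to_string(lp));
  } catch (const std::exception& e) {
    if (ReturnLpSamples) {
      // Called from multi-path pathfinder: let the exception terminate the
      // whole algorithm, as discussed above.
      throw;
    }
    // Single-path entry point: report the failure through the return code.
    logger.error(e.what());
    return 70;  // stands in for error_codes::SOFTWARE
  }
  return 0;  // stands in for error_codes::OK
}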
58 changes: 34 additions & 24 deletions src/stan/services/sample/fixed_param.hpp
@@ -74,9 +74,14 @@ int fixed_param(Model& model, const stan::io::var_context& init,
writer.write_diagnostic_names(s, sampler, model);

auto start = std::chrono::steady_clock::now();
- util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin,
-                            refresh, true, false, writer, s, model, rng,
-                            interrupt, logger);
+ try {
+   util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin,
+                              refresh, true, false, writer, s, model, rng,
+                              interrupt, logger);
+ } catch (const std::exception& e) {
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
auto end = std::chrono::steady_clock::now();
double sample_delta_t
= std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
@@ -156,27 +161,32 @@ int fixed_param(Model& model, const std::size_t num_chains,
writers[i].write_diagnostic_names(samples[i], samplers[i], model);
}

- tbb::parallel_for(
-     tbb::blocked_range<size_t>(0, num_chains, 1),
-     [&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger,
-      num_samples, num_thin, refresh, chain,
-      num_chains](const tbb::blocked_range<size_t>& r) {
-       for (size_t i = r.begin(); i != r.end(); ++i) {
-         auto start = std::chrono::steady_clock::now();
-         util::generate_transitions(samplers[i], num_samples, 0, num_samples,
-                                    num_thin, refresh, true, false, writers[i],
-                                    samples[i], model, rngs[i], interrupt,
-                                    logger, chain + i, num_chains);
-         auto end = std::chrono::steady_clock::now();
-         double sample_delta_t
-             = std::chrono::duration_cast<std::chrono::milliseconds>(end
-                                                                     - start)
-                   .count()
-               / 1000.0;
-         writers[i].write_timing(0.0, sample_delta_t);
-       }
-     },
-     tbb::simple_partitioner());
+ try {
+   tbb::parallel_for(
+       tbb::blocked_range<size_t>(0, num_chains, 1),
+       [&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger,
+        num_samples, num_thin, refresh, chain,
+        num_chains](const tbb::blocked_range<size_t>& r) {
+         for (size_t i = r.begin(); i != r.end(); ++i) {
+           auto start = std::chrono::steady_clock::now();
+           util::generate_transitions(
+               samplers[i], num_samples, 0, num_samples, num_thin, refresh,
+               true, false, writers[i], samples[i], model, rngs[i], interrupt,
+               logger, chain + i, num_chains);
+           auto end = std::chrono::steady_clock::now();
+           double sample_delta_t
+               = std::chrono::duration_cast<std::chrono::milliseconds>(end
+                                                                       - start)
+                     .count()
+                 / 1000.0;
+           writers[i].write_timing(0.0, sample_delta_t);
+         }
+       },
+       tbb::simple_partitioner());
+ } catch (const std::exception& e) {
+   logger.error(e.what());
+   return error_codes::SOFTWARE;
+ }
return error_codes::OK;
}
