From 643fd8b5ba649d1b3a849517fe1ab456e6917929 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 24 Jan 2024 16:24:32 -0500 Subject: [PATCH 1/6] Only swallow domain_errors --- src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp | 4 ++-- src/stan/services/pathfinder/single.hpp | 6 +++--- src/stan/services/util/gq_writer.hpp | 12 +++++++++++- src/stan/services/util/mcmc_writer.hpp | 7 ++++++- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp b/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp index ab47e01068d..abb575bfb60 100644 --- a/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp @@ -52,7 +52,7 @@ class base_hamiltonian { void update_potential(Point& z, callbacks::logger& logger) { try { z.V = -stan::model::log_prob_propto(model_, z.q); - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { this->write_error_msg_(e, logger); z.V = std::numeric_limits::infinity(); } @@ -62,7 +62,7 @@ class base_hamiltonian { try { stan::model::gradient(model_, z.q, z.V, z.g, logger); z.V = -z.V; - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { this->write_error_msg_(e, logger); z.V = std::numeric_limits::infinity(); } diff --git a/src/stan/services/pathfinder/single.hpp b/src/stan/services/pathfinder/single.hpp index fb7f3295a46..8536178f4f7 100644 --- a/src/stan/services/pathfinder/single.hpp +++ b/src/stan/services/pathfinder/single.hpp @@ -250,7 +250,7 @@ inline elbo_est_t est_approx_draws(LPF&& lp_fun, ConstrainF&& constrain_fun, approx_samples_col = approx_samples.col(i); ++lp_fun_calls; lp_mat.coeffRef(i, 1) = lp_fun(approx_samples_col, pathfinder_ss); - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { lp_mat.coeffRef(i, 1) = -std::numeric_limits::infinity(); } log_stream(logger, pathfinder_ss, iter_msg); @@ -530,7 +530,7 @@ auto pathfinder_impl(RNG&& rng, LPFun&& lp_fun, ConstrainFun&& constrain_fun, lp_fun, constrain_fun, rng, taylor_appx, num_elbo_draws, alpha, iter_msg, logger), taylor_appx); - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { logger.warn(iter_msg + "ELBO estimation failed " + " with error: " + e.what()); return std::make_pair(internal::elbo_est_t{}, internal::taylor_approx_t{}); @@ -900,7 +900,7 @@ inline auto pathfinder_lbfgs_single( approx_samples_constrained_col) .matrix(); } - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { std::string err_msg = e.what(); logger.warn(path_num + "Final sampling approximation failed with error: " + err_msg); diff --git a/src/stan/services/util/gq_writer.hpp b/src/stan/services/util/gq_writer.hpp index c844a54dd22..5cb8e6eb78a 100644 --- a/src/stan/services/util/gq_writer.hpp +++ b/src/stan/services/util/gq_writer.hpp @@ -79,10 +79,15 @@ class gq_writer { model.write_array(rng, draw, params_i, values, false, true, &ss); if (ss.str().length() > 0) logger_.info(ss); + } catch (const std::domain_error& e) { + if (ss.str().length() > 0) + logger_.info(ss); + logger_.info(e.what()); } catch (const std::exception& e) { if (ss.str().length() > 0) logger_.info(ss); logger_.info(e.what()); + throw; } std::vector gq_values(values.begin() + num_constrained_params_, @@ -110,11 +115,16 @@ class gq_writer { if (ss.str().length() > 0) { logger_.info(ss); } - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { if (ss.str().length() > 0) { logger_.info(ss); } logger_.info(e.what()); + } catch (const std::exception& e) { + if (ss.str().length() > 0) + logger_.info(ss); + logger_.info(e.what()); + throw; } sample_writer_(values); } diff --git a/src/stan/services/util/mcmc_writer.hpp b/src/stan/services/util/mcmc_writer.hpp index d51f168c9bc..7f5d5661858 100644 --- a/src/stan/services/util/mcmc_writer.hpp +++ b/src/stan/services/util/mcmc_writer.hpp @@ -111,11 +111,16 @@ class mcmc_writer { sample.cont_params().data() + sample.cont_params().size()); model.write_array(rng, cont_params, params_i, model_values, true, true, &ss); - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { if (ss.str().length() > 0) logger_.info(ss); ss.str(""); logger_.info(e.what()); + } catch (const std::exception& e) { + if (ss.str().length() > 0) + logger_.info(ss); + logger_.info(e.what()); + throw; } if (ss.str().length() > 0) logger_.info(ss); From 5681a8182fb5131a8f31d08934833037461a0b14 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 25 Jan 2024 09:45:34 -0500 Subject: [PATCH 2/6] A few more places --- src/stan/optimization/bfgs.hpp | 4 ++-- src/stan/optimization/newton.hpp | 2 +- src/stan/services/pathfinder/single.hpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stan/optimization/bfgs.hpp b/src/stan/optimization/bfgs.hpp index 53e7fa075bd..86a5a20ccae 100644 --- a/src/stan/optimization/bfgs.hpp +++ b/src/stan/optimization/bfgs.hpp @@ -309,7 +309,7 @@ class ModelAdaptor { try { f = -log_prob_propto(_model, _x, _params_i, _msgs); - } catch (const std::exception &e) { + } catch (const std::domain_error &e) { if (_msgs) (*_msgs) << e.what() << std::endl; return 1; @@ -341,7 +341,7 @@ class ModelAdaptor { try { f = -log_prob_grad(_model, _x, _params_i, _g, _msgs); - } catch (const std::exception &e) { + } catch (const std::domain_error &e) { if (_msgs) (*_msgs) << e.what() << std::endl; return 1; diff --git a/src/stan/optimization/newton.hpp b/src/stan/optimization/newton.hpp index 624a8c8dbc1..b6f24d3eb08 100644 --- a/src/stan/optimization/newton.hpp +++ b/src/stan/optimization/newton.hpp @@ -60,7 +60,7 @@ double newton_step(M& model, std::vector& params_r, try { f1 = stan::model::log_prob_grad(model, new_params_r, params_i, gradient); - } catch (std::exception& e) { + } catch (std::domain_error& e) { // FIXME: this is not a good way to handle a general exception f1 = -1e100; } diff --git a/src/stan/services/pathfinder/single.hpp b/src/stan/services/pathfinder/single.hpp index 8536178f4f7..0f0c7457ba0 100644 --- a/src/stan/services/pathfinder/single.hpp +++ b/src/stan/services/pathfinder/single.hpp @@ -820,7 +820,7 @@ inline auto pathfinder_lbfgs_single( logger.info(lbfgs_ss); lbfgs_ss.str(""); } - throw e; + throw; } } if (unlikely(save_iterations)) { From 2709df7f11af708477088f5609f6f70163ee5135 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 25 Jan 2024 10:40:49 -0500 Subject: [PATCH 3/6] Maintain exception safety in top-level services functions --- .../services/experimental/advi/fullrank.hpp | 9 +- .../services/experimental/advi/meanfield.hpp | 9 +- src/stan/services/optimize/bfgs.hpp | 42 +++++++-- src/stan/services/optimize/lbfgs.hpp | 30 ++++++- src/stan/services/optimize/newton.hpp | 2 +- src/stan/services/pathfinder/multi.hpp | 49 ++++++----- src/stan/services/pathfinder/single.hpp | 7 +- src/stan/services/sample/fixed_param.hpp | 58 +++++++------ src/stan/services/sample/hmc_nuts_dense_e.hpp | 44 ++++++---- .../sample/hmc_nuts_dense_e_adapt.hpp | 49 ++++++----- src/stan/services/sample/hmc_nuts_diag_e.hpp | 12 ++- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 49 ++++++----- src/stan/services/sample/hmc_nuts_unit_e.hpp | 47 ++++++---- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 49 ++++++----- .../services/sample/hmc_static_dense_e.hpp | 11 ++- .../sample/hmc_static_dense_e_adapt.hpp | 13 ++- .../services/sample/hmc_static_diag_e.hpp | 12 ++- .../sample/hmc_static_diag_e_adapt.hpp | 14 ++- .../services/sample/hmc_static_unit_e.hpp | 11 ++- .../sample/hmc_static_unit_e_adapt.hpp | 13 ++- src/stan/services/sample/standalone_gqs.hpp | 86 +++++++++++-------- 21 files changed, 395 insertions(+), 221 deletions(-) diff --git a/src/stan/services/experimental/advi/fullrank.hpp b/src/stan/services/experimental/advi/fullrank.hpp index 96cf5d05c0b..5fba2e4e026 100644 --- a/src/stan/services/experimental/advi/fullrank.hpp +++ b/src/stan/services/experimental/advi/fullrank.hpp @@ -86,8 +86,13 @@ int fullrank(Model& model, const stan::io::var_context& init, stan::rng_t> cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo, output_samples); - cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, - max_iterations, logger, parameter_writer, diagnostic_writer); + try { + cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, + max_iterations, logger, parameter_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return stan::services::error_codes::OK; } diff --git a/src/stan/services/experimental/advi/meanfield.hpp b/src/stan/services/experimental/advi/meanfield.hpp index 6cffe548acf..49bee285058 100644 --- a/src/stan/services/experimental/advi/meanfield.hpp +++ b/src/stan/services/experimental/advi/meanfield.hpp @@ -85,8 +85,13 @@ int meanfield(Model& model, const stan::io::var_context& init, stan::rng_t> cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo, output_samples); - cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, - max_iterations, logger, parameter_writer, diagnostic_writer); + try { + cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, + max_iterations, logger, parameter_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return stan::services::error_codes::OK; } diff --git a/src/stan/services/optimize/bfgs.hpp b/src/stan/services/optimize/bfgs.hpp index 2819b853a63..bcb0e49f31b 100644 --- a/src/stan/services/optimize/bfgs.hpp +++ b/src/stan/services/optimize/bfgs.hpp @@ -96,7 +96,16 @@ int bfgs(Model& model, const stan::io::var_context& init, if (save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); @@ -119,7 +128,13 @@ int bfgs(Model& model, const stan::io::var_context& init, " # evals" " Notes "); - ret = bfgs.step(); + try { + ret = bfgs.step(); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } + lp = bfgs.logp(); bfgs.params_r(cont_vector); @@ -150,8 +165,16 @@ int bfgs(Model& model, const stan::io::var_context& init, if (save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, - &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } // This if is here to match the pre-refactor behavior if (msg.str().length() > 0) logger.info(msg); @@ -164,7 +187,16 @@ int bfgs(Model& model, const stan::io::var_context& init, if (!save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); values.insert(values.begin(), lp); diff --git a/src/stan/services/optimize/lbfgs.hpp b/src/stan/services/optimize/lbfgs.hpp index 083e37ffed8..9045b5470e2 100644 --- a/src/stan/services/optimize/lbfgs.hpp +++ b/src/stan/services/optimize/lbfgs.hpp @@ -123,7 +123,12 @@ int lbfgs(Model& model, const stan::io::var_context& init, " # evals" " Notes "); - ret = lbfgs.step(); + try { + ret = lbfgs.step(); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } lp = lbfgs.logp(); lbfgs.params_r(cont_vector); @@ -154,8 +159,16 @@ int lbfgs(Model& model, const stan::io::var_context& init, if (save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, - &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); @@ -167,7 +180,16 @@ int lbfgs(Model& model, const stan::io::var_context& init, if (!save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); diff --git a/src/stan/services/optimize/newton.hpp b/src/stan/services/optimize/newton.hpp index 081365f0a9c..db64f6e46c3 100644 --- a/src/stan/services/optimize/newton.hpp +++ b/src/stan/services/optimize/newton.hpp @@ -62,7 +62,7 @@ int newton(Model& model, const stan::io::var_context& init, lp = model.template log_prob(cont_vector, disc_vector, &message); logger.info(message); - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { logger.info(""); logger.info( "Informational Message: The current" diff --git a/src/stan/services/pathfinder/multi.hpp b/src/stan/services/pathfinder/multi.hpp index e87eaa63e32..924f0806b37 100644 --- a/src/stan/services/pathfinder/multi.hpp +++ b/src/stan/services/pathfinder/multi.hpp @@ -117,28 +117,33 @@ inline int pathfinder_lbfgs_multi( individual_samples; individual_samples.resize(num_paths); std::atomic lp_calls{0}; - tbb::parallel_for( - tbb::blocked_range(0, num_paths), [&](tbb::blocked_range r) { - for (int iter = r.begin(); iter < r.end(); ++iter) { - auto pathfinder_ret - = stan::services::pathfinder::pathfinder_lbfgs_single( - model, *(init[iter]), random_seed, stride_id + iter, - init_radius, history_size, init_alpha, tol_obj, tol_rel_obj, - tol_grad, tol_rel_grad, tol_param, num_iterations, - num_elbo_draws, num_draws, save_iterations, refresh, - interrupt, logger, init_writers[iter], - single_path_parameter_writer[iter], - single_path_diagnostic_writer[iter], calculate_lp); - if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) { - logger.error(std::string("Pathfinder iteration: ") - + std::to_string(iter) + " failed."); - return; + try { + tbb::parallel_for( + tbb::blocked_range(0, num_paths), [&](tbb::blocked_range r) { + for (int iter = r.begin(); iter < r.end(); ++iter) { + auto pathfinder_ret + = stan::services::pathfinder::pathfinder_lbfgs_single( + model, *(init[iter]), random_seed, stride_id + iter, + init_radius, history_size, init_alpha, tol_obj, tol_rel_obj, + tol_grad, tol_rel_grad, tol_param, num_iterations, + num_elbo_draws, num_draws, save_iterations, refresh, + interrupt, logger, init_writers[iter], + single_path_parameter_writer[iter], + single_path_diagnostic_writer[iter], calculate_lp); + if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) { + logger.error(std::string("Pathfinder iteration: ") + + std::to_string(iter) + " failed."); + return; + } + individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret)); + individual_samples[iter] = std::move(std::get<2>(pathfinder_ret)); + lp_calls += std::get<3>(pathfinder_ret); } - individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret)); - individual_samples[iter] = std::move(std::get<2>(pathfinder_ret)); - lp_calls += std::get<3>(pathfinder_ret); - } - }); + }); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } // if any pathfinders failed, we want to remove their empty results individual_lp_ratios.erase( @@ -231,7 +236,7 @@ inline int pathfinder_lbfgs_multi( parameter_writer(total_time_str); } parameter_writer(); - return 0; + return error_codes::OK; } } // namespace pathfinder } // namespace services diff --git a/src/stan/services/pathfinder/single.hpp b/src/stan/services/pathfinder/single.hpp index 0f0c7457ba0..4719a40e1e9 100644 --- a/src/stan/services/pathfinder/single.hpp +++ b/src/stan/services/pathfinder/single.hpp @@ -820,7 +820,12 @@ inline auto pathfinder_lbfgs_single( logger.info(lbfgs_ss); lbfgs_ss.str(""); } - throw; + if (ReturnLpSamples) { + throw; + } else { + logger.error(e.what()); + return error_codes::SOFTWARE; + } } } if (unlikely(save_iterations)) { diff --git a/src/stan/services/sample/fixed_param.hpp b/src/stan/services/sample/fixed_param.hpp index f407b14f57f..17e04a621e2 100644 --- a/src/stan/services/sample/fixed_param.hpp +++ b/src/stan/services/sample/fixed_param.hpp @@ -74,9 +74,14 @@ int fixed_param(Model& model, const stan::io::var_context& init, writer.write_diagnostic_names(s, sampler, model); auto start = std::chrono::steady_clock::now(); - util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin, - refresh, true, false, writer, s, model, rng, - interrupt, logger); + try { + util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin, + refresh, true, false, writer, s, model, rng, + interrupt, logger); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } auto end = std::chrono::steady_clock::now(); double sample_delta_t = std::chrono::duration_cast(end - start) @@ -156,27 +161,32 @@ int fixed_param(Model& model, const std::size_t num_chains, writers[i].write_diagnostic_names(samples[i], samplers[i], model); } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger, - num_samples, num_thin, refresh, chain, - num_chains](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - auto start = std::chrono::steady_clock::now(); - util::generate_transitions(samplers[i], num_samples, 0, num_samples, - num_thin, refresh, true, false, writers[i], - samples[i], model, rngs[i], interrupt, - logger, chain + i, num_chains); - auto end = std::chrono::steady_clock::now(); - double sample_delta_t - = std::chrono::duration_cast(end - - start) - .count() - / 1000.0; - writers[i].write_timing(0.0, sample_delta_t); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger, + num_samples, num_thin, refresh, chain, + num_chains](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + auto start = std::chrono::steady_clock::now(); + util::generate_transitions( + samplers[i], num_samples, 0, num_samples, num_thin, refresh, + true, false, writers[i], samples[i], model, rngs[i], interrupt, + logger, chain + i, num_chains); + auto end = std::chrono::steady_clock::now(); + double sample_delta_t + = std::chrono::duration_cast(end + - start) + .count() + / 1000.0; + writers[i].write_timing(0.0, sample_delta_t); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_dense_e.hpp b/src/stan/services/sample/hmc_nuts_dense_e.hpp index 73e0a2af1d2..0fb818b19cb 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e.hpp @@ -81,9 +81,14 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init, sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -221,20 +226,25 @@ int hmc_nuts_dense_e(Model& model, size_t num_chains, logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, - &diagnostic_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, - num_samples, num_thin, refresh, save_warmup, - rngs[i], interrupt, logger, sample_writer[i], - diagnostic_writer[i], init_chain_id + i); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, + &diagnostic_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, + num_samples, num_thin, refresh, save_warmup, + rngs[i], interrupt, logger, sample_writer[i], + diagnostic_writer[i], init_chain_id + i); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index 913c50152af..ce1befff874 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -98,11 +98,15 @@ int hmc_nuts_dense_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); - - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, metric_writer); + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -379,21 +383,26 @@ int hmc_nuts_dense_e_adapt( logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, &diagnostic_writer, - &metric_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], metric_writer[i], - init_chain_id + i, num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, &diagnostic_writer, + &metric_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, num_samples, + num_thin, refresh, save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], metric_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_diag_e.hpp b/src/stan/services/sample/hmc_nuts_diag_e.hpp index e693ed0a834..bb789c01519 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e.hpp @@ -79,10 +79,14 @@ int hmc_nuts_diag_e(Model& model, const stan::io::var_context& init, sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); - + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 1044d9ed539..ec48ade0b5d 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -99,11 +99,15 @@ int hmc_nuts_diag_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, metric_writer); - + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -379,21 +383,26 @@ int hmc_nuts_diag_e_adapt( logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, &diagnostic_writer, - &metric_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], metric_writer[i], - init_chain_id + i, num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, &diagnostic_writer, + &metric_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, num_samples, + num_thin, refresh, save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], metric_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_unit_e.hpp b/src/stan/services/sample/hmc_nuts_unit_e.hpp index 01c9fe2e1b0..c7a2ca7c1bd 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e.hpp @@ -69,10 +69,14 @@ int hmc_nuts_unit_e(Model& model, const stan::io::var_context& init, sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); - + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -155,21 +159,26 @@ int hmc_nuts_unit_e(Model& model, size_t num_chains, logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, - &diagnostic_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, - num_samples, num_thin, refresh, save_warmup, - rngs[i], interrupt, logger, sample_writer[i], - diagnostic_writer[i], init_chain_id + i, - num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, + &diagnostic_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, + num_samples, num_thin, refresh, save_warmup, + rngs[i], interrupt, logger, sample_writer[i], + diagnostic_writer[i], init_chain_id + i, + num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 889c6d89200..5d74dae1767 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -83,11 +83,15 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_kappa(kappa); sampler.get_stepsize_adaptation().set_t0(t0); - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, metric_writer); - + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -228,21 +232,26 @@ int hmc_nuts_unit_e_adapt( logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, &diagnostic_writer, - &metric_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], metric_writer[i], - init_chain_id + i, num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, &diagnostic_writer, + &metric_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, num_samples, + num_thin, refresh, save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], metric_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_dense_e.hpp b/src/stan/services/sample/hmc_static_dense_e.hpp index c337636f9c7..c0931617383 100644 --- a/src/stan/services/sample/hmc_static_dense_e.hpp +++ b/src/stan/services/sample/hmc_static_dense_e.hpp @@ -76,9 +76,14 @@ int hmc_static_dense_e( sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp index 21bd6d711df..b56082620ad 100644 --- a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp @@ -98,10 +98,15 @@ int hmc_static_dense_e_adapt( logger); callbacks::structured_writer dummy_metric_writer; - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, dummy_metric_writer); + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, dummy_metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_diag_e.hpp b/src/stan/services/sample/hmc_static_diag_e.hpp index 87ea955f842..b19c211047a 100644 --- a/src/stan/services/sample/hmc_static_diag_e.hpp +++ b/src/stan/services/sample/hmc_static_diag_e.hpp @@ -78,10 +78,14 @@ int hmc_static_diag_e(Model& model, const stan::io::var_context& init, sampler.set_metric(inv_metric); sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); - - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp index 88979dc3419..e4041b59b3d 100644 --- a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp @@ -96,10 +96,16 @@ int hmc_static_diag_e_adapt( logger); callbacks::structured_writer dummy_metric_writer; - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, dummy_metric_writer); + + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, dummy_metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_unit_e.hpp b/src/stan/services/sample/hmc_static_unit_e.hpp index d50c902479f..8e2b8428fbe 100644 --- a/src/stan/services/sample/hmc_static_unit_e.hpp +++ b/src/stan/services/sample/hmc_static_unit_e.hpp @@ -68,9 +68,14 @@ int hmc_static_unit_e(Model& model, const stan::io::var_context& init, sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_unit_e_adapt.hpp b/src/stan/services/sample/hmc_static_unit_e_adapt.hpp index fb0da9aff54..bf4bf5c17e5 100644 --- a/src/stan/services/sample/hmc_static_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_unit_e_adapt.hpp @@ -80,10 +80,15 @@ int hmc_static_unit_e_adapt( sampler.get_stepsize_adaptation().set_t0(t0); callbacks::structured_writer dummy_metric_writer; - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, dummy_metric_writer); + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, dummy_metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/standalone_gqs.hpp b/src/stan/services/sample/standalone_gqs.hpp index 378be61106d..c792acd08f1 100644 --- a/src/stan/services/sample/standalone_gqs.hpp +++ b/src/stan/services/sample/standalone_gqs.hpp @@ -67,19 +67,23 @@ int standalone_generate(const Model &model, const Eigen::MatrixXd &draws, std::vector unconstrained_params_r; std::vector row(draws.cols()); - - for (size_t i = 0; i < draws.rows(); ++i) { - Eigen::Map(&row[0], draws.cols()) = draws.row(i); - try { - model.unconstrain_array(row, unconstrained_params_r, &msg); - } catch (const std::exception &e) { - if (msg.str().length() > 0) - logger.error(msg); - logger.error(e.what()); - return error_codes::DATAERR; + try { + for (size_t i = 0; i < draws.rows(); ++i) { + Eigen::Map(&row[0], draws.cols()) = draws.row(i); + try { + model.unconstrain_array(row, unconstrained_params_r, &msg); + } catch (const std::exception &e) { + if (msg.str().length() > 0) + logger.error(msg); + logger.error(e.what()); + return error_codes::DATAERR; + } + interrupt(); // call out to interrupt and fail + writer.write_gq_values(model, rng, unconstrained_params_r); } - interrupt(); // call out to interrupt and fail - writer.write_gq_values(model, rng, unconstrained_params_r); + } catch (const std::exception &e) { + logger.error(e.what()); + return error_codes::SOFTWARE; } return error_codes::OK; } @@ -147,34 +151,40 @@ int standalone_generate(const Model &model, const int num_chains, rngs.emplace_back(util::create_rng(seed, i + 1)); } bool error_any = false; - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [&draws, &model, &logger, &interrupt, &writers, &rngs, - &error_any](const tbb::blocked_range &r) { - Eigen::VectorXd unconstrained_params_r(draws[0].cols()); - Eigen::VectorXd row(draws[0].cols()); - std::stringstream msg; - for (size_t slice_idx = r.begin(); slice_idx != r.end(); ++slice_idx) { - for (size_t i = 0; i < draws[slice_idx].rows(); ++i) { - if (error_any) - return; - try { - row = draws[slice_idx].row(i); - model.unconstrain_array(row, unconstrained_params_r, &msg); - } catch (const std::exception &e) { - if (msg.str().length() > 0) - logger.error(msg); - logger.error(e.what()); - error_any = true; - return; + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [&draws, &model, &logger, &interrupt, &writers, &rngs, + &error_any](const tbb::blocked_range &r) { + Eigen::VectorXd unconstrained_params_r(draws[0].cols()); + Eigen::VectorXd row(draws[0].cols()); + std::stringstream msg; + for (size_t slice_idx = r.begin(); slice_idx != r.end(); + ++slice_idx) { + for (size_t i = 0; i < draws[slice_idx].rows(); ++i) { + if (error_any) + return; + try { + row = draws[slice_idx].row(i); + model.unconstrain_array(row, unconstrained_params_r, &msg); + } catch (const std::domain_error &e) { + if (msg.str().length() > 0) + logger.error(msg); + logger.error(e.what()); + error_any = true; + return; + } + interrupt(); // call out to interrupt and fail + writers[slice_idx].write_gq_values(model, rngs[slice_idx], + unconstrained_params_r); } - interrupt(); // call out to interrupt and fail - writers[slice_idx].write_gq_values(model, rngs[slice_idx], - unconstrained_params_r); } - } - }, - tbb::simple_partitioner()); + }, + tbb::simple_partitioner()); + } catch (const std::exception &e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_any ? error_codes::DATAERR : error_codes::OK; } From 759a5692b86ac4b12154379dba73319dff6ebe6d Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 25 Jan 2024 10:51:44 -0500 Subject: [PATCH 4/6] Fix pathfinder return --- src/stan/services/pathfinder/single.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/stan/services/pathfinder/single.hpp b/src/stan/services/pathfinder/single.hpp index 4719a40e1e9..9b19f0be29c 100644 --- a/src/stan/services/pathfinder/single.hpp +++ b/src/stan/services/pathfinder/single.hpp @@ -821,10 +821,14 @@ inline auto pathfinder_lbfgs_single( lbfgs_ss.str(""); } if (ReturnLpSamples) { + // we want to terminate multi-path pathfinder during these unrecoverable + // exceptions throw; } else { logger.error(e.what()); - return error_codes::SOFTWARE; + return internal::ret_pathfinder( + error_codes::SOFTWARE, Eigen::Array(0), + Eigen::Matrix(0, 0), 0); } } } From 24ef95f1cd6344e2d641646eab357fbb6cf16b56 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 13 Feb 2024 12:16:55 -0500 Subject: [PATCH 5/6] Move try/catch outside optimization loop --- src/stan/services/optimize/bfgs.hpp | 110 +++++++++++++-------------- src/stan/services/optimize/lbfgs.hpp | 107 ++++++++++++-------------- 2 files changed, 102 insertions(+), 115 deletions(-) diff --git a/src/stan/services/optimize/bfgs.hpp b/src/stan/services/optimize/bfgs.hpp index bcb0e49f31b..f282de6f98b 100644 --- a/src/stan/services/optimize/bfgs.hpp +++ b/src/stan/services/optimize/bfgs.hpp @@ -114,74 +114,68 @@ int bfgs(Model& model, const stan::io::var_context& init, } int ret = 0; - while (ret == 0) { - interrupt(); - if (refresh > 0 - && (bfgs.iter_num() == 0 || ((bfgs.iter_num() + 1) % refresh == 0))) - logger.info( - " Iter" - " log prob" - " ||dx||" - " ||grad||" - " alpha" - " alpha0" - " # evals" - " Notes "); + try { + while (ret == 0) { + interrupt(); + if (refresh > 0 + && (bfgs.iter_num() == 0 || ((bfgs.iter_num() + 1) % refresh == 0))) + logger.info( + " Iter" + " log prob" + " ||dx||" + " ||grad||" + " alpha" + " alpha0" + " # evals" + " Notes "); - try { ret = bfgs.step(); - } catch (const std::exception& e) { - logger.error(e.what()); - return error_codes::SOFTWARE; - } - lp = bfgs.logp(); - bfgs.params_r(cont_vector); - - if (refresh > 0 - && (ret != 0 || !bfgs.note().empty() || bfgs.iter_num() == 0 - || ((bfgs.iter_num() + 1) % refresh == 0))) { - std::stringstream msg; - msg << " " << std::setw(7) << bfgs.iter_num() << " "; - msg << " " << std::setw(12) << std::setprecision(6) << lp << " "; - msg << " " << std::setw(12) << std::setprecision(6) - << bfgs.prev_step_size() << " "; - msg << " " << std::setw(12) << std::setprecision(6) - << bfgs.curr_g().norm() << " "; - msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha() - << " "; - msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha0() - << " "; - msg << " " << std::setw(7) << bfgs.grad_evals() << " "; - msg << " " << bfgs.note() << " "; - logger.info(msg); - } + lp = bfgs.logp(); + bfgs.params_r(cont_vector); + + if (refresh > 0 + && (ret != 0 || !bfgs.note().empty() || bfgs.iter_num() == 0 + || ((bfgs.iter_num() + 1) % refresh == 0))) { + std::stringstream msg; + msg << " " << std::setw(7) << bfgs.iter_num() << " "; + msg << " " << std::setw(12) << std::setprecision(6) << lp << " "; + msg << " " << std::setw(12) << std::setprecision(6) + << bfgs.prev_step_size() << " "; + msg << " " << std::setw(12) << std::setprecision(6) + << bfgs.curr_g().norm() << " "; + msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha() + << " "; + msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha0() + << " "; + msg << " " << std::setw(7) << bfgs.grad_evals() << " "; + msg << " " << bfgs.note() << " "; + logger.info(msg); + } - if (bfgs_ss.str().length() > 0) { - logger.info(bfgs_ss); - bfgs_ss.str(""); - } + if (bfgs_ss.str().length() > 0) { + logger.info(bfgs_ss); + bfgs_ss.str(""); + } - if (save_iterations) { - std::vector values; - std::stringstream msg; - try { + if (save_iterations) { + std::vector values; + std::stringstream msg; model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); - } catch (const std::exception& e) { - if (msg.str().length() > 0) { + + // This if is here to match the pre-refactor behavior + if (msg.str().length() > 0) logger.info(msg); - } - logger.error(e.what()); - return error_codes::SOFTWARE; - } - // This if is here to match the pre-refactor behavior - if (msg.str().length() > 0) - logger.info(msg); - values.insert(values.begin(), lp); - parameter_writer(values); + values.insert(values.begin(), lp); + parameter_writer(values); + } } + + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; } if (!save_iterations) { diff --git a/src/stan/services/optimize/lbfgs.hpp b/src/stan/services/optimize/lbfgs.hpp index 9045b5470e2..bf50e9088f6 100644 --- a/src/stan/services/optimize/lbfgs.hpp +++ b/src/stan/services/optimize/lbfgs.hpp @@ -109,72 +109,65 @@ int lbfgs(Model& model, const stan::io::var_context& init, } int ret = 0; - while (ret == 0) { - interrupt(); - if (refresh > 0 - && (lbfgs.iter_num() == 0 || ((lbfgs.iter_num() + 1) % refresh == 0))) - logger.info( - " Iter" - " log prob" - " ||dx||" - " ||grad||" - " alpha" - " alpha0" - " # evals" - " Notes "); + try { + while (ret == 0) { + interrupt(); + if (refresh > 0 + && (lbfgs.iter_num() == 0 || ((lbfgs.iter_num() + 1) % refresh == 0))) + logger.info( + " Iter" + " log prob" + " ||dx||" + " ||grad||" + " alpha" + " alpha0" + " # evals" + " Notes "); - try { ret = lbfgs.step(); - } catch (const std::exception& e) { - logger.error(e.what()); - return error_codes::SOFTWARE; - } - lp = lbfgs.logp(); - lbfgs.params_r(cont_vector); - - if (refresh > 0 - && (ret != 0 || !lbfgs.note().empty() || lbfgs.iter_num() == 0 - || ((lbfgs.iter_num() + 1) % refresh == 0))) { - std::stringstream msg; - msg << " " << std::setw(7) << lbfgs.iter_num() << " "; - msg << " " << std::setw(12) << std::setprecision(6) << lp << " "; - msg << " " << std::setw(12) << std::setprecision(6) - << lbfgs.prev_step_size() << " "; - msg << " " << std::setw(12) << std::setprecision(6) - << lbfgs.curr_g().norm() << " "; - msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha() - << " "; - msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha0() - << " "; - msg << " " << std::setw(7) << lbfgs.grad_evals() << " "; - msg << " " << lbfgs.note() << " "; - logger.info(msg); - } - if (lbfgs_ss.str().length() > 0) { - logger.info(lbfgs_ss); - lbfgs_ss.str(""); - } + lp = lbfgs.logp(); + lbfgs.params_r(cont_vector); + + if (refresh > 0 + && (ret != 0 || !lbfgs.note().empty() || lbfgs.iter_num() == 0 + || ((lbfgs.iter_num() + 1) % refresh == 0))) { + std::stringstream msg; + msg << " " << std::setw(7) << lbfgs.iter_num() << " "; + msg << " " << std::setw(12) << std::setprecision(6) << lp << " "; + msg << " " << std::setw(12) << std::setprecision(6) + << lbfgs.prev_step_size() << " "; + msg << " " << std::setw(12) << std::setprecision(6) + << lbfgs.curr_g().norm() << " "; + msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha() + << " "; + msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha0() + << " "; + msg << " " << std::setw(7) << lbfgs.grad_evals() << " "; + msg << " " << lbfgs.note() << " "; + logger.info(msg); + } - if (save_iterations) { - std::vector values; - std::stringstream msg; - try { + if (lbfgs_ss.str().length() > 0) { + logger.info(lbfgs_ss); + lbfgs_ss.str(""); + } + + if (save_iterations) { + std::vector values; + std::stringstream msg; model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); - } catch (const std::exception& e) { - if (msg.str().length() > 0) { + if (msg.str().length() > 0) logger.info(msg); - } - logger.error(e.what()); - return error_codes::SOFTWARE; - } - if (msg.str().length() > 0) - logger.info(msg); - values.insert(values.begin(), lp); - parameter_writer(values); + values.insert(values.begin(), lp); + parameter_writer(values); + } } + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; } if (!save_iterations) { From c95036b4ea48d711aae15bfbc36fc26e406d36f0 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 13 Feb 2024 12:20:17 -0500 Subject: [PATCH 6/6] cpplint fix --- src/stan/services/optimize/bfgs.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stan/services/optimize/bfgs.hpp b/src/stan/services/optimize/bfgs.hpp index f282de6f98b..37cd0a58973 100644 --- a/src/stan/services/optimize/bfgs.hpp +++ b/src/stan/services/optimize/bfgs.hpp @@ -172,7 +172,6 @@ int bfgs(Model& model, const stan::io::var_context& init, parameter_writer(values); } } - } catch (const std::exception& e) { logger.error(e.what()); return error_codes::SOFTWARE;