diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 0e5234b5..10fd2b25 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -680,11 +680,19 @@ class GaussianMultivariateRegressionSuffStat { void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; if (dataset.HasVarWeights()) { - XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all)/dataset.VarWeightValue(row_idx); - ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all)))/dataset.VarWeightValue(row_idx); + for (int i = 0; i < p; i++) { + ytWX(0,i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i) / dataset.VarWeightValue(row_idx); + for (int j = 0; j < p; j++) { + XtWX(i,j) += dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j) / dataset.VarWeightValue(row_idx); + } + } } else { - XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all); - ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all))); + for (int i = 0; i < p; i++) { + ytWX(0,i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i); + for (int j = 0; j < p; j++) { + XtWX(i,j) += dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j); + } + } } } /*! @@ -692,8 +700,12 @@ class GaussianMultivariateRegressionSuffStat { */ void ResetSuffStat() { n = 0; - XtWX = Eigen::MatrixXd::Zero(p, p); - ytWX = Eigen::MatrixXd::Zero(1, p); + for (int i = 0; i < p; i++) { + ytWX(0, i) = 0.0; + for (int j = 0; j < p; j++) { + XtWX(i, j) = 0.0; + } + } } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs`