From 817ff540936ec4743e006d2ac7797659f45e3f35 Mon Sep 17 00:00:00 2001
From: fis
Date: Wed, 22 Jan 2020 12:27:59 +0800
Subject: [PATCH 1/3] Move prediction cache into Learner.

* Clean-ups

- Remove duplicated cache in Learner and GBM.
- Remove ad-hoc fix of invalid cache.
- Remove `PredictFromCache` in predictors.
- Remove the prediction cache for gblinear altogether: it merely moved
  prediction into the training process without providing any overall
  speed gain.
- The cache is now unique to Learner, so its ownership is no longer
  shared with any other component.

* Changes

- Add a version to the prediction cache.
- Use a weak pointer to check for expired DMatrix objects.
- Pass shared pointers instead of raw pointers.
---
 include/xgboost/gbm.h                     |  28 +++--
 include/xgboost/predictor.h               | 108 ++++++++++++------
 include/xgboost/tree_model.h              |   2 +-
 src/c_api/c_api.cc                        |  16 ++-
 src/gbm/gblinear.cc                       |  66 ++---------
 src/gbm/gblinear_model.h                  |   3 +-
 src/gbm/gbm.cc                            |   7 +-
 src/gbm/gbtree.cc                         |  39 ++++---
 src/gbm/gbtree.h                          |  27 ++---
 src/learner.cc                            |  82 +++++++++-----
 src/predictor/cpu_predictor.cc            | 130 +++++++++++----------
 src/predictor/gpu_predictor.cu            | 132 ++++++++++------------
 src/predictor/predictor.cc                |  44 +++++++-
 tests/cpp/gbm/test_gbtree.cc              |   4 +-
 tests/cpp/helpers.cc                      |  31 +++--
 tests/cpp/helpers.h                       |  13 ++-
 tests/cpp/predictor/test_cpu_predictor.cc |  23 ++--
 tests/cpp/predictor/test_gpu_predictor.cu |  30 +++--
 tests/cpp/predictor/test_predictor.cc     |  33 ++++++
 tests/python/test_basic.py                |  22 ++--
 20 files changed, 462 insertions(+), 378 deletions(-)
 create mode 100644 tests/cpp/predictor/test_predictor.cc

diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index 90645371b2f1..0f61e1c9fe90 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright by Contributors
+ * Copyright 2014-2020 by Contributors
  * \file gbm.h
  * \brief Interface of gradient booster,
  *        that learns through gradient statistics.
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include
 #include

 namespace xgboost {
@@ -28,6 +29,8 @@ class ObjFunction;
 struct GenericParameter;
 struct LearnerModelParam;
+struct PredictionCacheEntry;
+class PredictionContainer;

 /*!
  * \brief interface of gradient boosting model.
@@ -38,7 +41,7 @@ class GradientBooster : public Model, public Configurable {
  public:
   /*! \brief virtual destructor */
-  virtual ~GradientBooster() = default;
+  ~GradientBooster() override = default;
   /*!
    * \brief Set the configuration of gradient boosting.
    *  User must call configure once before InitModel and Training.
@@ -71,19 +74,22 @@ class GradientBooster : public Model, public Configurable {
-   * \param obj The objective function, optional, can be nullptr when use customized version
-   *            the booster may change content of gpair
+   * \param prediction The output prediction cache entry that needs to be updated.
+   *            The booster may change the content of gpair.
    */
-  virtual void DoBoost(DMatrix* p_fmat,
-                       HostDeviceVector* in_gpair,
-                       ObjFunction* obj = nullptr) = 0;
+  virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair,
+                       PredictionCacheEntry *prediction) = 0;
   /*!
    * \brief generate predictions for given feature matrix
    * \param dmat feature matrix
    * \param out_preds output vector to hold the predictions
-   * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
-   *    we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
+   * \param training Whether the prediction is used for training.  For the dart booster,
+   *                 dropout is performed during training.
+   * \param ntree_limit limit the number of trees used in prediction,
+   *                    when it equals 0, this means we do not limit
+   *                    the number of trees; this parameter is only valid
+   *                    for gbtree, but not for gblinear
    */
   virtual void PredictBatch(DMatrix* dmat,
-                            HostDeviceVector* out_preds,
+                            PredictionCacheEntry* out_preds,
                             bool training,
                             unsigned ntree_limit = 0) = 0;
   /*!
@@ -158,8 +164,7 @@ class GradientBooster : public Model, public Configurable {
   static GradientBooster* Create(
       const std::string& name,
       GenericParameter const* generic_param,
-      LearnerModelParam const* learner_model_param,
-      const std::vector >& cache_mats);
+      LearnerModelParam const* learner_model_param);

   static void AssertGPUSupport() {
 #ifndef XGBOOST_USE_CUDA
@@ -174,8 +179,7 @@
 struct GradientBoosterReg
     : public dmlc::FunctionRegEntryBase<
           GradientBoosterReg,
-          std::function > &cached_mats,
-                        LearnerModelParam const* learner_model_param)> > {
+          std::function > {
 };

 /*!
diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h
index b49a4899e33f..1b4991833a2d 100644
--- a/include/xgboost/predictor.h
+++ b/include/xgboost/predictor.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright by Contributors
+ * Copyright 2017-2020 by Contributors
  * \file predictor.h
  * \brief Interface of predictor,
  *  performs predictions for a gradient booster.
@@ -32,47 +32,83 @@ namespace xgboost {
  * \brief Contains pointer to input matrix and associated cached predictions.
  */
 struct PredictionCacheEntry {
-  std::shared_ptr data;
+  // A storage for caching prediction values
   HostDeviceVector predictions;
+  // The version of the current cache, corresponding to the number of committed tree layers.
+  uint32_t version;
+  // A weak pointer for checking whether the DMatrix object has expired.
+  std::weak_ptr<DMatrix> ref;
+
+  PredictionCacheEntry() : version { 0 } {}
+  /* \brief Advance the version of the cache entry.
+   *
+   * \param v Number of versions (tree layers) to add.
+   */
+  void Update(uint32_t v) {
+    version += v;
+  }
+};
+
+/* \brief A container for managed prediction caches.
+ */
+class PredictionContainer {
+  std::unordered_map container_;
+  void ClearExpiredEntries();
+
+ public:
+  PredictionContainer() = default;
+  /* \brief Add a new DMatrix to the cache.  At the same time this function clears out
+   *        all expired caches by checking the `std::weak_ptr`.  Caching an existing
+   *        DMatrix won't renew it.
+   *
+   *        Passing in a `shared_ptr` is critical here.  First, the shared pointer is
+   *        needed to create the `weak_ptr` inside the entry.  More importantly, the
+   *        lifetime of this cache is tied to the shared pointer.
+   *
+   *        Another way to make a safe cache is to create a proxy for this entry, with
+   *        another shared pointer defined inside, and pass that proxy around instead of
+   *        the real entry.  But that seems too messy.  In XGBoost, functions like
+   *        `UpdateOneIter` have (memory) safe access to the DMatrix as long as it's
+   *        passed in as a `shared_ptr`.
+   *
+   * \param m shared pointer to the DMatrix that needs to be cached.
+   * \param device Which device the cache should be allocated on.  Pass
+   *               GenericParameter::kCpuId for CPU or a positive integer for a GPU id.
+   *
+   * \return The cache entry for the passed in DMatrix, either an existing entry or a
+   *         newly created one.
+   */
+  PredictionCacheEntry& Cache(std::shared_ptr m, int32_t device);
+  /* \brief Get a prediction cache entry.  The entry must already have been allocated by
+   *        the `Cache` method; otherwise a dmlc::Error is thrown.
+   *
+   * \param m pointer to the DMatrix.
+ * \return The prediction cache for passed in DMatrix. + */ + PredictionCacheEntry& Entry(DMatrix* m); + /* \brief Get a const reference to the underlying hash map. Clear expired caches before + * returning. + */ + decltype(container_) const& Container(); }; /** * \class Predictor * - * \brief Performs prediction on individual training instances or batches of - * instances for GBTree. The predictor also manages a prediction cache - * associated with input matrices. If possible, it will use previously - * calculated predictions instead of calculating new predictions. - * Prediction functions all take a GBTreeModel and a DMatrix as input and - * output a vector of predictions. The predictor does not modify any state of - * the model itself. + * \brief Performs prediction on individual training instances or batches of instances for + * GBTree. Prediction functions all take a GBTreeModel and a DMatrix as input and + * output a vector of predictions. The predictor does not modify any state of the + * model itself. */ - class Predictor { protected: /* * \brief Runtime parameters. */ GenericParameter const* generic_param_; - /** - * \brief Map of matrices and associated cached predictions to facilitate - * storing and looking up predictions. - */ - std::shared_ptr> cache_; - - std::unordered_map::iterator FindCache(DMatrix const* dmat) { - auto cache_emtry = std::find_if( - cache_->begin(), cache_->end(), - [dmat](std::pair const &kv) { - return kv.second.data.get() == dmat; - }); - return cache_emtry; - } public: - Predictor(GenericParameter const* generic_param, - std::shared_ptr> cache) : - generic_param_{generic_param}, cache_{cache} {} + explicit Predictor(GenericParameter const* generic_param) : + generic_param_{generic_param} {} virtual ~Predictor() = default; /** @@ -91,12 +127,11 @@ class Predictor { * \param model The model to predict from. * \param tree_begin The tree begin index. * \param ntree_limit (Optional) The ntree limit. 0 means do not - * limit trees. + * limit trees. */ - - virtual void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, + virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds, const gbm::GBTreeModel& model, int tree_begin, - unsigned ntree_limit = 0) = 0; + uint32_t const ntree_limit = 0) = 0; /** * \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel @@ -116,7 +151,9 @@ class Predictor { virtual void UpdatePredictionCache( const gbm::GBTreeModel& model, std::vector>* updaters, - int num_new_trees) = 0; + int num_new_trees, + DMatrix* m, + PredictionCacheEntry* predts) = 0; /** * \fn virtual void Predictor::PredictInstance( const SparsePage::Inst& @@ -200,8 +237,7 @@ class Predictor { * \param cache Pointer to prediction cache. */ static Predictor* Create( - std::string const& name, GenericParameter const* generic_param, - std::shared_ptr> cache); + std::string const& name, GenericParameter const* generic_param); }; /*! @@ -209,9 +245,7 @@ class Predictor { */ struct PredictorReg : public dmlc::FunctionRegEntryBase< - PredictorReg, std::function>)>> {}; + PredictorReg, std::function> {}; #define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \ static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \ diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 3bcb1cc3be53..69f3a718b662 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -158,7 +158,7 @@ class RegTree : public Model { } /*! 
\brief whether this node is deleted */ XGBOOST_DEVICE bool IsDeleted() const { - return sindex_ == std::numeric_limits::max(); + return sindex_ == std::numeric_limits::max(); } /*! \brief whether current node is root */ XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; } diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ccc49d91a402..849db1ab22fe 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -11,8 +11,8 @@ #include #include - #include "xgboost/data.h" +#include "xgboost/host_device_vector.h" #include "xgboost/learner.h" #include "xgboost/c_api.h" #include "xgboost/logging.h" @@ -146,7 +146,7 @@ struct XGBAPIThreadLocalEntry { /*! \brief result holder for returning string pointers */ std::vector ret_vec_charp; /*! \brief returning float vector. */ - std::vector ret_vec_float; + HostDeviceVector ret_vec_float; /*! \brief temp variable of gradient pairs. */ std::vector tmp_gpair; }; @@ -553,24 +553,22 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, int32_t training, xgboost::bst_ulong *len, const bst_float **out_result) { - std::vector& preds = - XGBAPIThreadLocalStore::Get()->ret_vec_float; API_BEGIN(); CHECK_HANDLE(); + HostDeviceVector& preds = + XGBAPIThreadLocalStore::Get()->ret_vec_float; auto *bst = static_cast(handle); - HostDeviceVector tmp_preds; bst->Predict( *static_cast*>(dmat), (option_mask & 1) != 0, - &tmp_preds, ntree_limit, + &preds, ntree_limit, static_cast(training), (option_mask & 2) != 0, (option_mask & 4) != 0, (option_mask & 8) != 0, (option_mask & 16) != 0); - preds = tmp_preds.HostVector(); - *out_result = dmlc::BeginPtr(preds); - *len = static_cast(preds.size()); + *out_result = dmlc::BeginPtr(preds.HostVector()); + *len = static_cast(preds.Size()); API_END(); } diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 9724e37b5bb9..3fff06d992bb 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -15,6 +15,7 @@ #include "xgboost/gbm.h" #include "xgboost/json.h" +#include "xgboost/predictor.h" #include "xgboost/linear_updater.h" #include "xgboost/logging.h" #include "xgboost/learner.h" @@ -50,21 +51,14 @@ struct GBLinearTrainParam : public XGBoostParameter { */ class GBLinear : public GradientBooster { public: - explicit GBLinear(const std::vector > &cache, - LearnerModelParam const* learner_model_param) + explicit GBLinear(LearnerModelParam const* learner_model_param) : learner_model_param_{learner_model_param}, model_{learner_model_param_}, previous_model_{learner_model_param_}, sum_instance_weight_(0), sum_weight_complete_(false), - is_converged_(false) { - // Add matrices to the prediction cache - for (auto &d : cache) { - PredictionCacheEntry e; - e.data = d; - cache_[d.get()] = std::move(e); - } - } + is_converged_(false) {} + void Configure(const Args& cfg) override { if (model_.weight.size() == 0) { model_.Configure(cfg); @@ -118,7 +112,7 @@ class GBLinear : public GradientBooster { void DoBoost(DMatrix *p_fmat, HostDeviceVector *in_gpair, - ObjFunction* obj) override { + PredictionCacheEntry* predt) override { monitor_.Start("DoBoost"); model_.LazyInitModel(); @@ -127,28 +121,19 @@ class GBLinear : public GradientBooster { if (!this->CheckConvergence()) { updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_); } - this->UpdatePredictionCache(); monitor_.Stop("DoBoost"); } void PredictBatch(DMatrix *p_fmat, - HostDeviceVector *out_preds, + PredictionCacheEntry *predts, bool training, unsigned ntree_limit) override { monitor_.Start("PredictBatch"); + auto* out_preds = 
&predts->predictions; CHECK_EQ(ntree_limit, 0U) << "GBLinear::Predict ntrees is only valid for gbtree predictor"; - - // Try to predict from cache - auto it = cache_.find(p_fmat); - if (it != cache_.end() && it->second.predictions.size() != 0) { - std::vector &y = it->second.predictions; - out_preds->Resize(y.size()); - std::copy(y.begin(), y.end(), out_preds->HostVector().begin()); - } else { - this->PredictBatchInternal(p_fmat, &out_preds->HostVector()); - } + this->PredictBatchInternal(p_fmat, &out_preds->HostVector()); monitor_.Stop("PredictBatch"); } // add base margin @@ -258,7 +243,8 @@ class GBLinear : public GradientBooster { const size_t ridx = batch.base_rowid + i; // loop over output groups for (int gid = 0; gid < ngroup; ++gid) { - bst_float margin = (base_margin.size() != 0) ? + bst_float margin = + (base_margin.size() != 0) ? base_margin[ridx * ngroup + gid] : learner_model_param_->base_score; this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); } @@ -266,17 +252,6 @@ class GBLinear : public GradientBooster { } monitor_.Stop("PredictBatchInternal"); } - void UpdatePredictionCache() { - // update cache entry - for (auto &kv : cache_) { - PredictionCacheEntry &e = kv.second; - if (e.predictions.size() == 0) { - size_t n = model_.learner_model_param_->num_output_group * e.data->Info().num_row_; - e.predictions.resize(n); - } - this->PredictBatchInternal(e.data.get(), &e.predictions); - } - } bool CheckConvergence() { if (param_.tolerance == 0.0f) return false; @@ -327,22 +302,6 @@ class GBLinear : public GradientBooster { bool sum_weight_complete_; common::Monitor monitor_; bool is_converged_; - - /** - * \struct PredictionCacheEntry - * - * \brief Contains pointer to input matrix and associated cached predictions. - */ - struct PredictionCacheEntry { - std::shared_ptr data; - std::vector predictions; - }; - - /** - * \brief Map of matrices and associated cached predictions to facilitate - * storing and looking up predictions. - */ - std::unordered_map cache_; }; // register the objective functions @@ -350,9 +309,8 @@ DMLC_REGISTER_PARAMETER(GBLinearTrainParam); XGBOOST_REGISTER_GBM(GBLinear, "gblinear") .describe("Linear booster, implement generalized linear model.") - .set_body([](const std::vector > &cache, - LearnerModelParam const* booster_config) { - return new GBLinear(cache, booster_config); + .set_body([](LearnerModelParam const* booster_config) { + return new GBLinear(booster_config); }); } // namespace gbm } // namespace xgboost diff --git a/src/gbm/gblinear_model.h b/src/gbm/gblinear_model.h index 71b8bcd061c2..1564ffedda8f 100644 --- a/src/gbm/gblinear_model.h +++ b/src/gbm/gblinear_model.h @@ -55,8 +55,9 @@ class GBLinearModel : public Model { std::vector weight; // initialize the model parameter inline void LazyInitModel() { - if (!weight.empty()) + if (!weight.empty()) { return; + } // bias is the last weight weight.resize((learner_model_param_->num_feature + 1) * learner_model_param_->num_output_group); diff --git a/src/gbm/gbm.cc b/src/gbm/gbm.cc index 2d6a89556477..87a6ded29042 100644 --- a/src/gbm/gbm.cc +++ b/src/gbm/gbm.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015 by Contributors + * Copyright 2015-2020 by Contributors * \file gbm.cc * \brief Registry of gradient boosters. 
*/ @@ -20,13 +20,12 @@ namespace xgboost { GradientBooster* GradientBooster::Create( const std::string& name, GenericParameter const* generic_param, - LearnerModelParam const* learner_model_param, - const std::vector >& cache_mats) { + LearnerModelParam const* learner_model_param) { auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown gbm type " << name; } - auto p_bst = (e->body)(cache_mats, learner_model_param); + auto p_bst = (e->body)(learner_model_param); p_bst->generic_param_ = generic_param; return p_bst; } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index abd2b9fa8b82..c318d860b446 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2019 by Contributors + * Copyright 2014-2020 by Contributors * \file gbtree.cc * \brief gradient boosted tree implementation. * \author Tianqi Chen @@ -14,6 +14,7 @@ #include #include +#include "xgboost/data.h" #include "xgboost/gbm.h" #include "xgboost/logging.h" #include "xgboost/json.h" @@ -47,14 +48,14 @@ void GBTree::Configure(const Args& cfg) { // configure predictors if (!cpu_predictor_) { cpu_predictor_ = std::unique_ptr( - Predictor::Create("cpu_predictor", this->generic_param_, cache_)); + Predictor::Create("cpu_predictor", this->generic_param_)); } cpu_predictor_->Configure(cfg); #if defined(XGBOOST_USE_CUDA) auto n_gpus = common::AllVisibleGPUs(); if (!gpu_predictor_ && n_gpus != 0) { gpu_predictor_ = std::unique_ptr( - Predictor::Create("gpu_predictor", this->generic_param_, cache_)); + Predictor::Create("gpu_predictor", this->generic_param_)); } if (n_gpus != 0) { gpu_predictor_->Configure(cfg); @@ -183,7 +184,7 @@ void GBTree::ConfigureUpdaters() { void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, - ObjFunction* obj) { + PredictionCacheEntry* predt) { std::vector > > new_trees; const int ngroup = model_.learner_model_param_->num_output_group; ConfigureWithKnownData(this->cfg_, p_fmat); @@ -195,7 +196,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, new_trees.push_back(std::move(ret)); } else { CHECK_EQ(in_gpair->Size() % ngroup, 0U) - << "must have exactly ngroup*nrow gpairs"; + << "must have exactly ngroup * nrow gpairs"; // TODO(canonizer): perform this on GPU if HostDeviceVector has device set. 
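+      // Note on layout (cf. the `ngroup * nrow` check above): in_gpair holds
+      // one GradientPair per (row, group) in row-major order, so the pair for
+      // row r and group gid sits at index r * ngroup + gid.  `tmp` below
+      // receives a single group's slice of the gradients so that each output
+      // group's trees are boosted from their own gradient statistics.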
HostDeviceVector tmp(in_gpair->Size() / ngroup, GradientPair(), @@ -214,7 +215,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, } } monitor_.Stop("BoostNewTrees"); - this->CommitModel(std::move(new_trees)); + this->CommitModel(std::move(new_trees), p_fmat, predt); } void GBTree::InitUpdater(Args const& cfg) { @@ -286,7 +287,9 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, } } -void GBTree::CommitModel(std::vector>>&& new_trees) { +void GBTree::CommitModel(std::vector>>&& new_trees, + DMatrix* m, + PredictionCacheEntry* predts) { monitor_.Start("CommitModel"); int num_new_trees = 0; for (uint32_t gid = 0; gid < model_.learner_model_param_->num_output_group; ++gid) { @@ -294,7 +297,7 @@ void GBTree::CommitModel(std::vector>>&& ne model_.CommitModel(std::move(new_trees[gid]), gid); } CHECK(configured_); - GetPredictor()->UpdatePredictionCache(model_, &updaters_, num_new_trees); + GetPredictor()->UpdatePredictionCache(model_, &updaters_, num_new_trees, m, predts); monitor_.Stop("CommitModel"); } @@ -303,13 +306,16 @@ void GBTree::LoadConfig(Json const& in) { fromJson(in["gbtree_train_param"], &tparam_); int32_t const n_gpus = xgboost::common::AllVisibleGPUs(); if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) { + LOG(WARNING) + << "Loading from a raw memory buffer on CPU only machine. " + "Changing predictor to auto."; tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}}); } if (n_gpus == 0 && tparam_.tree_method == TreeMethod::kGPUHist) { tparam_.UpdateAllowUnknown(Args{{"tree_method", "hist"}}); LOG(WARNING) << "Loading from a raw memory buffer on CPU only machine. " - "Change tree_method to hist."; + "Changing tree_method to hist."; } auto const& j_updaters = get(in["updater"]); @@ -415,7 +421,7 @@ class Dart : public GBTree { } void PredictBatch(DMatrix* p_fmat, - HostDeviceVector* p_out_preds, + PredictionCacheEntry* p_out_preds, bool training, unsigned ntree_limit) override { DropTrees(training); @@ -426,7 +432,7 @@ class Dart : public GBTree { } size_t n = num_group * p_fmat->Info().num_row_; const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector(); - auto& out_preds = p_out_preds->HostVector(); + auto& out_preds = p_out_preds->predictions.HostVector(); out_preds.resize(n); if (base_margin.size() != 0) { CHECK_EQ(out_preds.size(), n); @@ -539,7 +545,9 @@ class Dart : public GBTree { // commit new trees all at once void - CommitModel(std::vector>>&& new_trees) override { + CommitModel(std::vector>>&& new_trees, + DMatrix* m, + PredictionCacheEntry* predts) override { int num_new_trees = 0; for (uint32_t gid = 0; gid < model_.learner_model_param_->num_output_group; ++gid) { num_new_trees += new_trees[gid].size(); @@ -681,16 +689,13 @@ DMLC_REGISTER_PARAMETER(DartTrainParam); XGBOOST_REGISTER_GBM(GBTree, "gbtree") .describe("Tree booster, gradient boosted trees.") -.set_body([](const std::vector >& cached_mats, - LearnerModelParam const* booster_config) { +.set_body([](LearnerModelParam const* booster_config) { auto* p = new GBTree(booster_config); - p->InitCache(cached_mats); return p; }); XGBOOST_REGISTER_GBM(Dart, "dart") .describe("Tree booster, dart.") -.set_body([](const std::vector >& cached_mats, - LearnerModelParam const* booster_config) { +.set_body([](LearnerModelParam const* booster_config) { GBTree* p = new Dart(booster_config); return p; }); diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index fb546969d5b2..f05d084d2856 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -16,6 +16,7 @@ #include #include +#include "xgboost/data.h" 
#include "xgboost/logging.h" #include "xgboost/gbm.h" #include "xgboost/predictor.h" @@ -151,14 +152,8 @@ struct DartTrainParam : public XGBoostParameter { // gradient boosted trees class GBTree : public GradientBooster { public: - explicit GBTree(LearnerModelParam const* booster_config) : model_(booster_config) {} - - void InitCache(const std::vector > &cache) { - cache_ = std::make_shared>(); - for (std::shared_ptr const& d : cache) { - (*cache_)[d.get()].data = d; - } - } + explicit GBTree(LearnerModelParam const* booster_config) : + model_(booster_config) {} void Configure(const Args& cfg) override; // Revise `tree_method` and `updater` parameters after seeing the training @@ -171,7 +166,7 @@ class GBTree : public GradientBooster { /*! \brief Carry out one iteration of boosting */ void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, - ObjFunction* obj) override; + PredictionCacheEntry* predt) override; bool UseGPU() const override { return @@ -204,11 +199,12 @@ class GBTree : public GradientBooster { } void PredictBatch(DMatrix* p_fmat, - HostDeviceVector* out_preds, + PredictionCacheEntry* out_preds, bool training, unsigned ntree_limit) override { CHECK(configured_); - GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); + GetPredictor(&out_preds->predictions, p_fmat)->PredictBatch( + p_fmat, out_preds, model_, 0, ntree_limit); } void PredictInstance(const SparsePage::Inst& inst, @@ -318,7 +314,9 @@ class GBTree : public GradientBooster { } // commit new trees all at once - virtual void CommitModel(std::vector>>&& new_trees); + virtual void CommitModel(std::vector>>&& new_trees, + DMatrix* m, + PredictionCacheEntry* predts); // --- data structure --- GBTreeModel model_; @@ -332,11 +330,6 @@ class GBTree : public GradientBooster { Args cfg_; // the updaters that can be applied to each of tree std::vector> updaters_; - /** - * \brief Map of matrices and associated cached predictions to facilitate - * storing and looking up predictions. - */ - std::shared_ptr> cache_; // Predictors std::unique_ptr cpu_predictor_; #if defined(XGBOOST_USE_CUDA) diff --git a/src/learner.cc b/src/learner.cc index 10b7882c6900..4b3eeedafbbd 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -17,6 +18,8 @@ #include #include "xgboost/base.h" +#include "xgboost/data.h" +#include "xgboost/predictor.h" #include "xgboost/feature_map.h" #include "xgboost/gbm.h" #include "xgboost/generic_parameters.h" @@ -187,9 +190,12 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) { */ class LearnerImpl : public Learner { public: - explicit LearnerImpl(std::vector > cache) - : need_configuration_{true}, cache_(std::move(cache)) { + explicit LearnerImpl(std::vector > cache) + : need_configuration_{true} { monitor_.Init("Learner"); + for (std::shared_ptr const& d : cache) { + cache_.Cache(d, GenericParameter::kCpuId); + } } // Configuration before data is known. 
void Configure() override { @@ -333,8 +339,7 @@ class LearnerImpl : public Learner { name = get(gradient_booster["name"]); tparam_.UpdateAllowUnknown(Args{{"booster", name}}); gbm_.reset(GradientBooster::Create(tparam_.booster, - &generic_parameters_, &learner_model_param_, - cache_)); + &generic_parameters_, &learner_model_param_)); gbm_->LoadModel(gradient_booster); learner_model_param_ = LearnerModelParam(mparam_, @@ -391,8 +396,7 @@ class LearnerImpl : public Learner { tparam_.booster = get(gradient_booster["name"]); if (!gbm_) { gbm_.reset(GradientBooster::Create(tparam_.booster, - &generic_parameters_, &learner_model_param_, - cache_)); + &generic_parameters_, &learner_model_param_)); } gbm_->LoadConfig(gradient_booster); @@ -492,7 +496,7 @@ class LearnerImpl : public Learner { // duplicated code with LazyInitModel obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, - &learner_model_param_, cache_)); + &learner_model_param_)); gbm_->Load(fi); if (mparam_.contain_extra_attrs != 0) { std::vector > attr; @@ -701,17 +705,18 @@ class LearnerImpl : public Learner { this->CheckDataSplitMode(); this->ValidateDMatrix(train.get()); + auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id); + monitor_.Start("PredictRaw"); - this->PredictRaw(train.get(), &preds_[train.get()], true); + this->PredictRaw(train.get(), &predt, true); monitor_.Stop("PredictRaw"); - TrainingObserver::Instance().Observe(preds_[train.get()], "Predictions"); monitor_.Start("GetGradient"); - obj_->GetGradient(preds_[train.get()], train->Info(), iter, &gpair_); + obj_->GetGradient(predt.predictions, train->Info(), iter, &gpair_); monitor_.Stop("GetGradient"); TrainingObserver::Instance().Observe(gpair_, "Gradients"); - gbm_->DoBoost(train.get(), &gpair_, obj_.get()); + gbm_->DoBoost(train.get(), &gpair_, &predt); monitor_.Stop("UpdateOneIter"); } @@ -724,12 +729,14 @@ class LearnerImpl : public Learner { } this->CheckDataSplitMode(); this->ValidateDMatrix(train.get()); + this->cache_.Cache(train, generic_parameters_.gpu_id); - gbm_->DoBoost(train.get(), in_gpair); + gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get())); monitor_.Stop("BoostOneIter"); } - std::string EvalOneIter(int iter, const std::vector>& data_sets, + std::string EvalOneIter(int iter, + const std::vector>& data_sets, const std::vector& data_names) override { monitor_.Start("EvalOneIter"); this->Configure(); @@ -741,13 +748,19 @@ class LearnerImpl : public Learner { metrics_.back()->Configure({cfg_.begin(), cfg_.end()}); } for (size_t i = 0; i < data_sets.size(); ++i) { - DMatrix * dmat = data_sets[i].get(); - this->ValidateDMatrix(dmat); - this->PredictRaw(dmat, &preds_[dmat], false); - obj_->EvalTransform(&preds_[dmat]); + std::shared_ptr dmat = data_sets[i]; + auto& predt = this->cache_.Cache(dmat, generic_parameters_.gpu_id); + this->ValidateDMatrix(dmat.get()); + this->PredictRaw(dmat.get(), &predt, false); + + predictions_.SetDevice(generic_parameters_.gpu_id); + predictions_.Resize(predt.predictions.Size()); + predictions_.Copy(predt.predictions); + + obj_->EvalTransform(&predictions_); for (auto& ev : metrics_) { os << '\t' << data_names[i] << '-' << ev->Name() << ':' - << ev->Eval(preds_[dmat], data_sets[i]->Info(), + << ev->Eval(predictions_, data_sets[i]->Info(), tparam_.dsplit == DataSplitMode::kRow); } } @@ -823,7 +836,12 @@ class LearnerImpl : public Learner { } else if (pred_leaf) { gbm_->PredictLeaf(data.get(), 
&out_preds->HostVector(), ntree_limit); } else { - this->PredictRaw(data.get(), out_preds, training, ntree_limit); + auto& prediction = cache_.Cache(data, generic_parameters_.gpu_id); + this->PredictRaw(data.get(), &prediction, training, ntree_limit); + // Copy the prediction cache to output prediction. out_preds comes from C API + out_preds->SetDevice(generic_parameters_.gpu_id); + out_preds->Resize(prediction.predictions.Size()); + out_preds->Copy(prediction.predictions); if (!output_margin) { obj_->PredTransform(out_preds); } @@ -843,11 +861,10 @@ class LearnerImpl : public Learner { * predictor, when it equals 0, this means we are using all the trees * \param training allow dropout when the DART booster is being used */ - void PredictRaw(DMatrix* data, HostDeviceVector* out_preds, + void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds, bool training, unsigned ntree_limit = 0) const { - CHECK(gbm_ != nullptr) - << "Predict must happen after Load or configuration"; + CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; this->ValidateDMatrix(data); gbm_->PredictBatch(data, out_preds, training, ntree_limit); } @@ -894,7 +911,7 @@ class LearnerImpl : public Learner { void ConfigureGBM(LearnerTrainParam const& old, Args const& args) { if (gbm_ == nullptr || old.booster != tparam_.booster) { gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, - &learner_model_param_, cache_)); + &learner_model_param_)); } gbm_->Configure(args); } @@ -904,9 +921,10 @@ class LearnerImpl : public Learner { // estimate feature bound // TODO(hcho3): Change num_feature to 64-bit integer unsigned num_feature = 0; - for (auto & matrix : cache_) { - CHECK(matrix != nullptr); - const uint64_t num_col = matrix->Info().num_col_; + for (auto & matrix : cache_.Container()) { + CHECK(matrix.first); + CHECK(!matrix.second.ref.expired()); + const uint64_t num_col = matrix.first->Info().num_col_; CHECK_LE(num_col, static_cast(std::numeric_limits::max())) << "Unfortunately, XGBoost does not support data matrices with " << std::numeric_limits::max() << " features or greater"; @@ -965,13 +983,12 @@ class LearnerImpl : public Learner { // `enable_experimental_json_serialization' is set to false. Will be removed once JSON // takes over. std::string const serialisation_header_ { u8"CONFIG-offset:" }; - // configurations + // User provided configurations std::map cfg_; + // Stores information like best-iteration for early stopping. std::map attributes_; std::vector metric_names_; static std::string const kEvalMetric; // NOLINT - // temporal storages for prediction - std::map> preds_; // gradient pairs HostDeviceVector gpair_; bool need_configuration_; @@ -979,8 +996,11 @@ class LearnerImpl : public Learner { private: /*! \brief random number transformation seed. */ static int32_t constexpr kRandSeedMagic = 127; - // internal cached dmatrix - std::vector > cache_; + // internal cached dmatrix for prediction. + PredictionContainer cache_; + /*! \brief Temporary storage to prediction. Useful for storing data transformed by + * objective function */ + HostDeviceVector predictions_; common::Monitor monitor_; diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index d11822782fd4..74c955612c2f 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -1,5 +1,5 @@ /*! 
- * Copyright by Contributors 2017-2019 + * Copyright by Contributors 2017-2020 */ #include @@ -46,9 +46,9 @@ class CPUPredictor : public Predictor { } } - void PredLoopInternal(DMatrix* p_fmat, std::vector* out_preds, - gbm::GBTreeModel const& model, int32_t tree_begin, - int32_t tree_end) { + void PredInternal(DMatrix *p_fmat, std::vector *out_preds, + gbm::GBTreeModel const &model, int32_t tree_begin, + int32_t tree_end) { int32_t const num_group = model.learner_model_param_->num_output_group; const int nthread = omp_get_max_threads(); InitThreadTemp(nthread, model.learner_model_param_->num_feature); @@ -102,27 +102,6 @@ class CPUPredictor : public Predictor { } } - bool PredictFromCache(DMatrix* dmat, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model, - unsigned ntree_limit) const { - CHECK(cache_); - if (ntree_limit == 0 || - ntree_limit * model.learner_model_param_->num_output_group >= model.trees.size()) { - auto it = cache_->find(dmat); - if (it != cache_->end()) { - const HostDeviceVector& y = it->second.predictions; - if (y.Size() != 0) { - out_preds->Resize(y.Size()); - std::copy(y.HostVector().begin(), y.HostVector().end(), - out_preds->HostVector().begin()); - return true; - } - } - } - return false; - } - void InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, const gbm::GBTreeModel& model) const { @@ -156,60 +135,78 @@ class CPUPredictor : public Predictor { } public: - CPUPredictor(GenericParameter const* generic_param, - std::shared_ptr> cache) : - Predictor::Predictor{generic_param, cache} {} - void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, + explicit CPUPredictor(GenericParameter const* generic_param) : + Predictor::Predictor{generic_param} {} + // ntree_limit is a very problematic parameter, as it's ambiguous in the context of + // multi-output and forest. Same problem exists for tree_begin + void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, const gbm::GBTreeModel& model, int tree_begin, - unsigned ntree_limit = 0) override { - if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { - return; + uint32_t const ntree_limit = 0) override { + // tree_begin is not used, right now we just enforce it to be 0. 
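+    // Illustrative sketch of the versioning scheme: `version` counts the tree
+    // layers already folded into the cached prediction, one layer being the
+    // trees committed by a single boosting round.  E.g. with 3 classes and 12
+    // trees, a cache at version 2 already covers trees [0, 6); predicting with
+    // the full model computes trees [6, 12) on top of the cached values and
+    // bumps the version to 4.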
+    CHECK_EQ(tree_begin, 0);
+    auto* out_preds = &predts->predictions;
+    CHECK_GE(predts->version, tree_begin);
+    if (predts->version == 0) {
+      CHECK_EQ(out_preds->Size(), 0);
+      this->InitOutPredictions(dmat->Info(), out_preds, model);
     }
-    this->InitOutPredictions(dmat->Info(), out_preds, model);
-    ntree_limit *= model.learner_model_param_->num_output_group;
-    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
-      ntree_limit = static_cast(model.trees.size());
+    uint32_t const output_groups = model.learner_model_param_->num_output_group;
+    CHECK_NE(output_groups, 0);
+    // Right now we assume that the ntree_limit provided by the user means the
+    // number of tree layers in the context of a multi-output model.
+    uint32_t real_ntree_limit = ntree_limit * output_groups;
+    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
+      real_ntree_limit = static_cast(model.trees.size());
     }
-    this->PredLoopInternal(dmat, &out_preds->HostVector(), model,
-                           tree_begin, ntree_limit);
-
-    auto cache_entry = this->FindCache(dmat);
-    if (cache_entry == cache_->cend()) {
-      return;
+    uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
+    // When the user has provided ntree_limit, end_version can be smaller than the
+    // cached version; the cache is then invalid and must be rebuilt from scratch.
+    if (predts->version > end_version) {
+      CHECK_NE(ntree_limit, 0);
+      this->InitOutPredictions(dmat->Info(), out_preds, model);
+      predts->version = 0;
     }
-    if (cache_entry->second.predictions.Size() == 0) {
-      // See comment in GPUPredictor::PredictBatch.
-      InitOutPredictions(cache_entry->second.data->Info(),
-                         &(cache_entry->second.predictions), model);
-      cache_entry->second.predictions.Copy(*out_preds);
+    uint32_t const beg_version = predts->version;
+    CHECK_LE(beg_version, end_version);
+
+    if (beg_version < end_version) {
+      this->PredInternal(dmat, &out_preds->HostVector(), model,
+                         beg_version * output_groups,
+                         end_version * output_groups);
     }
+
+    // delta means {size of forest} * {number of newly accumulated layers}
+    uint32_t delta = end_version - beg_version;
+    CHECK_LE(delta, model.trees.size());
+    predts->Update(delta);
+
+    CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
+          out_preds->Size() == dmat->Info().num_row_);
   }

   void UpdatePredictionCache(
       const gbm::GBTreeModel& model,
       std::vector>* updaters,
-      int num_new_trees) override {
+      int num_new_trees,
+      DMatrix* m,
+      PredictionCacheEntry* predts) override {
     int old_ntree = model.trees.size() - num_new_trees;
     // update cache entry
-    for (auto& kv : (*cache_)) {
-      PredictionCacheEntry& e = kv.second;
-
-      if (e.predictions.Size() == 0) {
-        InitOutPredictions(e.data->Info(), &(e.predictions), model);
-        PredLoopInternal(e.data.get(), &(e.predictions.HostVector()), model, 0,
-                         model.trees.size());
-      } else if (model.learner_model_param_->num_output_group == 1 && updaters->size() > 0 &&
-                 num_new_trees == 1 &&
-                 updaters->back()->UpdatePredictionCache(e.data.get(),
-                                                         &(e.predictions))) {
-        {}  // do nothing
-      } else {
-        PredLoopInternal(e.data.get(), &(e.predictions.HostVector()), model, old_ntree,
-                         model.trees.size());
-      }
+    auto* out = &predts->predictions;
+    if (predts->predictions.Size() == 0) {
+      this->InitOutPredictions(m->Info(), out, model);
+      this->PredInternal(m, &out->HostVector(), model, 0, model.trees.size());
+    } else if (model.learner_model_param_->num_output_group == 1 &&
+               updaters->size() > 0 &&
+               num_new_trees == 1 &&
+               updaters->back()->UpdatePredictionCache(m, out)) {
+      {}  // do nothing: the updater has refreshed the cached prediction in place
+    } else {
+      PredInternal(m, &out->HostVector(), model, old_ntree, model.trees.size());
     }
+    auto
delta = num_new_trees / model.learner_model_param_->num_output_group; + predts->Update(delta); } void PredictInstance(const SparsePage::Inst& inst, @@ -387,9 +384,8 @@ class CPUPredictor : public Predictor { XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor") .describe("Make predictions using CPU.") -.set_body([](GenericParameter const* generic_param, - std::shared_ptr> cache) { - return new CPUPredictor(generic_param, cache); +.set_body([](GenericParameter const* generic_param) { + return new CPUPredictor(generic_param); }); } // namespace predictor } // namespace xgboost diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 1ff8d796db99..eac4bda96d09 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2018 by Contributors + * Copyright 2017-2020 by Contributors */ #include #include @@ -295,9 +295,8 @@ class GPUPredictor : public xgboost::Predictor { } public: - GPUPredictor(GenericParameter const* generic_param, - std::shared_ptr> cache) : - Predictor::Predictor{generic_param, cache} {} + explicit GPUPredictor(GenericParameter const* generic_param) : + Predictor::Predictor{generic_param} {} ~GPUPredictor() override { if (generic_param_->gpu_id >= 0) { @@ -305,43 +304,53 @@ class GPUPredictor : public xgboost::Predictor { } } - void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, + void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit = 0) override { + // This function is duplicated with CPU predictor PredictBatch, see comments in there. + // FIXME(trivialfis): Remove the duplication. int device = generic_param_->gpu_id; CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data."; ConfigureDevice(device); - if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { - return; + CHECK_EQ(tree_begin, 0); + auto* out_preds = &predts->predictions; + CHECK_GE(predts->version, tree_begin); + if (predts->version == 0) { + CHECK_EQ(out_preds->Size(), 0); + this->InitOutPredictions(dmat->Info(), out_preds, model); + } + + uint32_t const output_groups = model.learner_model_param_->num_output_group; + CHECK_NE(output_groups, 0); + + uint32_t real_ntree_limit = ntree_limit * output_groups; + if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { + real_ntree_limit = static_cast(model.trees.size()); } - this->InitOutPredictions(dmat->Info(), out_preds, model); - int32_t tree_end = ntree_limit * model.learner_model_param_->num_output_group; + uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups; - if (ntree_limit == 0 || ntree_limit > model.trees.size()) { - tree_end = static_cast(model.trees.size()); + if (predts->version > end_version) { + CHECK_NE(ntree_limit, 0); + this->InitOutPredictions(dmat->Info(), out_preds, model); + predts->version = 0; } + uint32_t const beg_version = predts->version; + CHECK_LE(beg_version, end_version); - DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end); - - auto cache_emtry = this->FindCache(dmat); - if (cache_emtry == cache_->cend()) { return; } - if (cache_emtry->second.predictions.Size() == 0) { - // Initialise the cache on first iteration, this comes useful - // when performing training continuation: - // - // 1. PredictBatch - // 2. 
CommitModel - // - updater->UpdatePredictionCache - // - // If we don't initialise this cache, the 2 step will recieve an invalid cache as - // the first step only modifies prediction store in learner without following code. - InitOutPredictions(cache_emtry->second.data->Info(), - &(cache_emtry->second.predictions), model); - CHECK_EQ(cache_emtry->second.predictions.Size(), out_preds->Size()); - cache_emtry->second.predictions.Copy(*out_preds); + if (beg_version < end_version) { + this->DevicePredictInternal(dmat, out_preds, model, + beg_version * output_groups, + end_version * output_groups); } + + uint32_t delta = end_version - beg_version; + CHECK_LE(delta, model.trees.size()); + predts->Update(delta); + + CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ || + out_preds->Size() == dmat->Info().num_row_); } protected: @@ -361,49 +370,30 @@ class GPUPredictor : public xgboost::Predictor { } } - bool PredictFromCache(DMatrix* dmat, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) { - if (ntree_limit == 0 || - ntree_limit * model.learner_model_param_->num_output_group >= model.trees.size()) { - auto it = (*cache_).find(dmat); - if (it != cache_->cend()) { - const HostDeviceVector& y = it->second.predictions; - if (y.Size() != 0) { - monitor_.StartCuda("PredictFromCache"); - out_preds->SetDevice(y.DeviceIdx()); - out_preds->Resize(y.Size()); - out_preds->Copy(y); - monitor_.StopCuda("PredictFromCache"); - return true; - } - } - } - return false; - } - void UpdatePredictionCache( const gbm::GBTreeModel& model, std::vector>* updaters, - int num_new_trees) override { + int num_new_trees, + DMatrix* m, + PredictionCacheEntry* predts) override { + int device = generic_param_->gpu_id; + ConfigureDevice(device); auto old_ntree = model.trees.size() - num_new_trees; // update cache entry - for (auto& kv : (*cache_)) { - PredictionCacheEntry& e = kv.second; - DMatrix* dmat = kv.first; - HostDeviceVector& predictions = e.predictions; - - if (predictions.Size() == 0) { - this->InitOutPredictions(dmat->Info(), &predictions, model); - } - - if (model.learner_model_param_->num_output_group == 1 && updaters->size() > 0 && - num_new_trees == 1 && - updaters->back()->UpdatePredictionCache(e.data.get(), &predictions)) { - // do nothing - } else { - DevicePredictInternal(dmat, &predictions, model, old_ntree, model.trees.size()); - } + auto* out = &predts->predictions; + if (predts->predictions.Size() == 0) { + InitOutPredictions(m->Info(), out, model); + DevicePredictInternal(m, out, model, 0, model.trees.size()); + } else if (model.learner_model_param_->num_output_group == 1 && + updaters->size() > 0 && + num_new_trees == 1 && + updaters->back()->UpdatePredictionCache(m, out)) { + {} + } else { + DevicePredictInternal(m, out, model, old_ntree, model.trees.size()); } + auto delta = num_new_trees / model.learner_model_param_->num_output_group; + predts->Update(delta); } void PredictInstance(const SparsePage::Inst& inst, @@ -442,11 +432,6 @@ class GPUPredictor : public xgboost::Predictor { void Configure(const std::vector>& cfg) override { Predictor::Configure(cfg); - - int device = generic_param_->gpu_id; - if (device >= 0) { - ConfigureDevice(device); - } } private: @@ -469,9 +454,8 @@ class GPUPredictor : public xgboost::Predictor { XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") .describe("Make predictions using GPU.") -.set_body([](GenericParameter const* generic_param, - std::shared_ptr> cache) { - return new GPUPredictor(generic_param, cache); 
+.set_body([](GenericParameter const* generic_param) { + return new GPUPredictor(generic_param); }); } // namespace predictor diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 115caa8396e9..b20edf880f6c 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -1,24 +1,60 @@ /*! - * Copyright by Contributors 2017 + * Copyright 2017-2020 by Contributors */ #include #include +#include "xgboost/data.h" +#include "xgboost/generic_parameters.h" + namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); } // namespace dmlc namespace xgboost { + +void PredictionContainer::ClearExpiredEntries() { + std::vector expired; + for (auto& kv : container_) { + if (kv.second.ref.expired()) { + expired.emplace_back(kv.first); + } + } + for (auto const& ptr : expired) { + container_.erase(ptr); + } +} + +PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr m, int32_t device) { + this->ClearExpiredEntries(); + container_[m.get()].ref = m; + if (device != GenericParameter::kCpuId) { + container_[m.get()].predictions.SetDevice(device); + } + return container_[m.get()]; +} + +PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) { + CHECK(container_.find(m) != container_.cend()); + CHECK(container_.at(m).ref.lock()) + << "[Internal error]: DMatrix: " << m << " has expired."; + return container_.at(m); +} + +decltype(PredictionContainer::container_) const& PredictionContainer::Container() { + this->ClearExpiredEntries(); + return container_; +} + void Predictor::Configure( const std::vector>& cfg) { } Predictor* Predictor::Create( - std::string const& name, GenericParameter const* generic_param, - std::shared_ptr> cache) { + std::string const& name, GenericParameter const* generic_param) { auto* e = ::dmlc::Registry::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown predictor type " << name; } - auto p_predictor = (e->body)(generic_param, cache); + auto p_predictor = (e->body)(generic_param); return p_predictor; } } // namespace xgboost diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 0b92b4878ea9..f3655c009385 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -10,6 +10,7 @@ #include "xgboost/learner.h" #include "../helpers.h" #include "../../../src/gbm/gbtree.h" +#include "xgboost/predictor.h" namespace xgboost { TEST(GBTree, SelectTreeMethod) { @@ -22,9 +23,8 @@ TEST(GBTree, SelectTreeMethod) { mparam.num_feature = kCols; mparam.num_output_group = 1; - std::vector > caches; std::unique_ptr p_gbm { - GradientBooster::Create("gbtree", &generic_param, &mparam, caches)}; + GradientBooster::Create("gbtree", &generic_param, &mparam)}; auto& gbtree = dynamic_cast (*p_gbm); // Test if `tree_method` can be set diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index a479622cbd85..a5c89eebb709 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -1,8 +1,11 @@ /*! 
- * Copyright 2016-2019 XGBoost contributors + * Copyright 2016-2020 XGBoost contributors */ #include #include +#include +#include +#include #include #include #include @@ -16,6 +19,7 @@ #include "../../src/data/simple_csr_source.h" #include "../../src/gbm/gbtree_model.h" +#include "xgboost/predictor.h" bool FileExists(const std::string& filename) { struct stat st; @@ -259,13 +263,19 @@ std::unique_ptr CreateSparsePageDMatrixWithRC( } } -gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param) { - std::vector> trees; - trees.push_back(std::unique_ptr(new RegTree)); - (*trees.back())[0].SetLeaf(1.5f); - (*trees.back()).Stat(0).sum_hess = 1.0f; +gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes) { gbm::GBTreeModel model(param); - model.CommitModel(std::move(trees), 0); + + for (size_t i = 0; i < n_classes; ++i) { + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + if (i == 0) { + (*trees.back())[0].SetLeaf(1.5f); + (*trees.back()).Stat(0).sum_hess = 1.0f; + } + model.CommitModel(std::move(trees), i); + } + return model; } @@ -273,8 +283,9 @@ std::unique_ptr CreateTrainedGBM( std::string name, Args kwargs, size_t kRows, size_t kCols, LearnerModelParam const* learner_model_param, GenericParameter const* generic_param) { + auto caches = std::make_shared< PredictionContainer >();; std::unique_ptr gbm { - GradientBooster::Create(name, generic_param, learner_model_param, {})}; + GradientBooster::Create(name, generic_param, learner_model_param)}; gbm->Configure(kwargs); auto pp_dmat = CreateDMatrix(kRows, kCols, 0); auto p_dmat = *pp_dmat; @@ -291,7 +302,9 @@ std::unique_ptr CreateTrainedGBM( h_gpair[i] = {static_cast(i), 1}; } - gbm->DoBoost(p_dmat.get(), &gpair, nullptr); + PredictionCacheEntry predts; + + gbm->DoBoost(p_dmat.get(), &gpair, &predts); delete pp_dmat; return gbm; diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index ca124234337e..aaad7955ea14 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -16,16 +16,13 @@ #include #include -#include -#include #include -#include #include #include -#include #include "../../src/common/common.h" #include "../../src/common/hist_util.h" +#include "../../src/gbm/gbtree_model.h" #if defined(__CUDACC__) #include "../../src/data/ellpack_page.cuh" #endif @@ -42,6 +39,12 @@ #define GPUIDX -1 #endif +namespace xgboost { +class ObjFunction; +class Metric; +struct LearnerModelParam; +} + bool FileExists(const std::string& filename); int64_t GetFileSize(const std::string& filename); @@ -206,7 +209,7 @@ std::unique_ptr CreateSparsePageDMatrixWithRC( size_t n_rows, size_t n_cols, size_t page_size, bool deterministic, const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory()); -gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param); +gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes = 1); std::unique_ptr CreateTrainedGBM( std::string name, Args kwargs, size_t kRows, size_t kCols, diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 6ef39f4e66a2..9b13cf576047 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -1,4 +1,6 @@ -// Copyright by Contributors +/*! 
+ * Copyright 2017-2020 XGBoost contributors + */ #include #include #include @@ -9,9 +11,8 @@ namespace xgboost { TEST(CpuPredictor, Basic) { auto lparam = CreateEmptyGenericParam(GPUIDX); - auto cache = std::make_shared>(); std::unique_ptr cpu_predictor = - std::unique_ptr(Predictor::Create("cpu_predictor", &lparam, cache)); + std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); int kRows = 5; int kCols = 5; @@ -26,10 +27,11 @@ TEST(CpuPredictor, Basic) { auto dmat = CreateDMatrix(kRows, kCols, 0); // Test predict batch - HostDeviceVector out_predictions; + PredictionCacheEntry out_predictions; cpu_predictor->PredictBatch((*dmat).get(), &out_predictions, model, 0); - std::vector& out_predictions_h = out_predictions.HostVector(); - for (size_t i = 0; i < out_predictions.Size(); i++) { + ASSERT_EQ(model.trees.size(), out_predictions.version); + std::vector& out_predictions_h = out_predictions.predictions.HostVector(); + for (size_t i = 0; i < out_predictions.predictions.Size(); i++) { ASSERT_EQ(out_predictions_h[i], 1.5); } @@ -81,10 +83,9 @@ TEST(CpuPredictor, ExternalMemory) { std::string filename = tmpdir.path + "/big.libsvm"; std::unique_ptr dmat = CreateSparsePageDMatrix(12, 64, filename); auto lparam = CreateEmptyGenericParam(GPUIDX); - auto cache = std::make_shared>(); std::unique_ptr cpu_predictor = - std::unique_ptr(Predictor::Create("cpu_predictor", &lparam, cache)); + std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); LearnerModelParam param; param.base_score = 0; @@ -94,10 +95,10 @@ TEST(CpuPredictor, ExternalMemory) { gbm::GBTreeModel model = CreateTestModel(¶m); // Test predict batch - HostDeviceVector out_predictions; + PredictionCacheEntry out_predictions; cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); - std::vector &out_predictions_h = out_predictions.HostVector(); - ASSERT_EQ(out_predictions.Size(), dmat->Info().num_row_); + std::vector &out_predictions_h = out_predictions.predictions.HostVector(); + ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_); for (const auto& v : out_predictions_h) { ASSERT_EQ(v, 1.5); } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index ffc8743b42d9..48c68335b846 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -1,6 +1,5 @@ - /*! 
- * Copyright 2017-2019 XGBoost contributors + * Copyright 2017-2020 XGBoost contributors */ #include #include @@ -19,12 +18,11 @@ namespace predictor { TEST(GpuPredictor, Basic) { auto cpu_lparam = CreateEmptyGenericParam(-1); auto gpu_lparam = CreateEmptyGenericParam(0); - auto cache = std::make_shared>(); std::unique_ptr gpu_predictor = - std::unique_ptr(Predictor::Create("gpu_predictor", &gpu_lparam, cache)); + std::unique_ptr(Predictor::Create("gpu_predictor", &gpu_lparam)); std::unique_ptr cpu_predictor = - std::unique_ptr(Predictor::Create("cpu_predictor", &cpu_lparam, cache)); + std::unique_ptr(Predictor::Create("cpu_predictor", &cpu_lparam)); gpu_predictor->Configure({}); cpu_predictor->Configure({}); @@ -41,16 +39,17 @@ TEST(GpuPredictor, Basic) { gbm::GBTreeModel model = CreateTestModel(¶m); // Test predict batch - HostDeviceVector gpu_out_predictions; - HostDeviceVector cpu_out_predictions; + PredictionCacheEntry gpu_out_predictions; + PredictionCacheEntry cpu_out_predictions; gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0); + ASSERT_EQ(model.trees.size(), gpu_out_predictions.version); cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0); - std::vector& gpu_out_predictions_h = gpu_out_predictions.HostVector(); - std::vector& cpu_out_predictions_h = cpu_out_predictions.HostVector(); + std::vector& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector(); + std::vector& cpu_out_predictions_h = cpu_out_predictions.predictions.HostVector(); float abs_tolerance = 0.001; - for (int j = 0; j < gpu_out_predictions.Size(); j++) { + for (int j = 0; j < gpu_out_predictions.predictions.Size(); j++) { ASSERT_NEAR(gpu_out_predictions_h[j], cpu_out_predictions_h[j], abs_tolerance); } delete dmat; @@ -59,9 +58,8 @@ TEST(GpuPredictor, Basic) { TEST(gpu_predictor, ExternalMemoryTest) { auto lparam = CreateEmptyGenericParam(0); - auto cache = std::make_shared>(); std::unique_ptr gpu_predictor = - std::unique_ptr(Predictor::Create("gpu_predictor", &lparam, cache)); + std::unique_ptr(Predictor::Create("gpu_predictor", &lparam)); gpu_predictor->Configure({}); LearnerModelParam param; @@ -70,7 +68,7 @@ TEST(gpu_predictor, ExternalMemoryTest) { param.num_output_group = n_classes; param.base_score = 0.5; - gbm::GBTreeModel model = CreateTestModel(¶m); + gbm::GBTreeModel model = CreateTestModel(¶m, n_classes); std::vector> dmats; dmlc::TemporaryDirectory tmpdir; std::string file0 = tmpdir.path + "/big_0.libsvm"; @@ -82,10 +80,10 @@ TEST(gpu_predictor, ExternalMemoryTest) { for (const auto& dmat: dmats) { dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5); - HostDeviceVector out_predictions; + PredictionCacheEntry out_predictions; gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); - EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes); - const std::vector &host_vector = out_predictions.ConstHostVector(); + EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes); + const std::vector &host_vector = out_predictions.predictions.ConstHostVector(); for (int i = 0; i < host_vector.size() / n_classes; i++) { ASSERT_EQ(host_vector[i * n_classes], 2.0); ASSERT_EQ(host_vector[i * n_classes + 1], 0.5); diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc new file mode 100644 index 000000000000..927697e6bda2 --- /dev/null +++ b/tests/cpp/predictor/test_predictor.cc @@ -0,0 +1,33 @@ +/*! 
+ * Copyright 2020 by Contributors + */ + +#include +#include +#include +#include + +#include "../helpers.h" +#include "xgboost/generic_parameters.h" + +namespace xgboost { +TEST(Predictor, PredictionCache) { + size_t constexpr kRows = 16, kCols = 4; + + PredictionContainer container; + DMatrix* m; + // Add a cache that is immediately expired. + auto add_cache = [&]() { + auto *pp_dmat = CreateDMatrix(kRows, kCols, 0); + auto p_dmat = *pp_dmat; + container.Cache(p_dmat, GenericParameter::kCpuId); + m = p_dmat.get(); + delete pp_dmat; + }; + + add_cache(); + ASSERT_EQ(container.Container().size(), 0); + add_cache(); + EXPECT_ANY_THROW(container.Entry(m)); +} +} // namespace xgboost diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index d27c50d71d5c..a38f9a26cac6 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -19,11 +19,12 @@ @contextmanager def captured_output(): - """ - Reassign stdout temporarily in order to test printed statements - Taken from: https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python + """Reassign stdout temporarily in order to test printed statements + Taken from: + https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python Also works for pytest. + """ new_out, new_err = StringIO(), StringIO() old_out, old_err = sys.stdout, sys.stderr @@ -39,13 +40,20 @@ class TestBasic(unittest.TestCase): def test_basic(self): dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') - param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} # specify validations set to watch performance - watchlist = [(dtest, 'eval'), (dtrain, 'train')] + watchlist = [(dtrain, 'train')] num_round = 2 - bst = xgb.train(param, dtrain, num_round, watchlist) - # this is prediction + bst = xgb.train(param, dtrain, num_round, watchlist, verbose_eval=True) + + preds = bst.predict(dtrain) + labels = dtrain.get_label() + err = sum(1 for i in range(len(preds)) + if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) + # error must be smaller than 10% + assert err < 0.1 + preds = bst.predict(dtest) labels = dtest.get_label() err = sum(1 for i in range(len(preds)) From 2bce357e6a52fc5b74706643b9bd5b43bc3a80aa Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 11 Feb 2020 23:47:34 +0800 Subject: [PATCH 2/3] Revert C API change. --- src/c_api/c_api.cc | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 849db1ab22fe..ccc49d91a402 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -11,8 +11,8 @@ #include #include + #include "xgboost/data.h" -#include "xgboost/host_device_vector.h" #include "xgboost/learner.h" #include "xgboost/c_api.h" #include "xgboost/logging.h" @@ -146,7 +146,7 @@ struct XGBAPIThreadLocalEntry { /*! \brief result holder for returning string pointers */ std::vector ret_vec_charp; /*! \brief returning float vector. */ - HostDeviceVector ret_vec_float; + std::vector ret_vec_float; /*! \brief temp variable of gradient pairs. 
*/ std::vector tmp_gpair; }; @@ -553,22 +553,24 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, int32_t training, xgboost::bst_ulong *len, const bst_float **out_result) { + std::vector& preds = + XGBAPIThreadLocalStore::Get()->ret_vec_float; API_BEGIN(); CHECK_HANDLE(); - HostDeviceVector& preds = - XGBAPIThreadLocalStore::Get()->ret_vec_float; auto *bst = static_cast(handle); + HostDeviceVector tmp_preds; bst->Predict( *static_cast*>(dmat), (option_mask & 1) != 0, - &preds, ntree_limit, + &tmp_preds, ntree_limit, static_cast(training), (option_mask & 2) != 0, (option_mask & 4) != 0, (option_mask & 8) != 0, (option_mask & 16) != 0); - *out_result = dmlc::BeginPtr(preds.HostVector()); - *len = static_cast(preds.Size()); + preds = tmp_preds.HostVector(); + *out_result = dmlc::BeginPtr(preds); + *len = static_cast(preds.size()); API_END(); } From 4fcc43e9ea518e5cef364c3f26a2cf0b62b47ff2 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 12 Feb 2020 13:57:01 +0800 Subject: [PATCH 3/3] Prediction vector per matrix. --- src/learner.cc | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/learner.cc b/src/learner.cc index 4b3eeedafbbd..c055ed6beb5b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -748,20 +748,19 @@ class LearnerImpl : public Learner { metrics_.back()->Configure({cfg_.begin(), cfg_.end()}); } for (size_t i = 0; i < data_sets.size(); ++i) { - std::shared_ptr dmat = data_sets[i]; - auto& predt = this->cache_.Cache(dmat, generic_parameters_.gpu_id); - this->ValidateDMatrix(dmat.get()); - this->PredictRaw(dmat.get(), &predt, false); + std::shared_ptr m = data_sets[i]; + auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id); + this->ValidateDMatrix(m.get()); + this->PredictRaw(m.get(), &predt, false); - predictions_.SetDevice(generic_parameters_.gpu_id); - predictions_.Resize(predt.predictions.Size()); - predictions_.Copy(predt.predictions); + auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions; + out.Resize(predt.predictions.Size()); + out.Copy(predt.predictions); - obj_->EvalTransform(&predictions_); + obj_->EvalTransform(&out); for (auto& ev : metrics_) { os << '\t' << data_names[i] << '-' << ev->Name() << ':' - << ev->Eval(predictions_, data_sets[i]->Info(), - tparam_.dsplit == DataSplitMode::kRow); + << ev->Eval(out, m->Info(), tparam_.dsplit == DataSplitMode::kRow); } } @@ -1000,7 +999,7 @@ class LearnerImpl : public Learner { PredictionContainer cache_; /*! \brief Temporary storage to prediction. Useful for storing data transformed by * objective function */ - HostDeviceVector predictions_; + PredictionContainer output_predictions_; common::Monitor monitor_;
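
Reviewer note: below is a minimal, self-contained sketch of the cache ownership
scheme these patches introduce.  The names are hypothetical (Matrix, CacheEntry
and Container stand in for DMatrix, PredictionCacheEntry and
PredictionContainer), and the real entry stores a HostDeviceVector rather than
a std::vector; this only illustrates the weak_ptr-based expiry and the
layer-counting version, not the actual XGBoost implementation.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <memory>
#include <unordered_map>
#include <vector>

// Stand-in for xgboost::DMatrix.
struct Matrix {
  size_t rows;
};

// Mirrors PredictionCacheEntry: cached values, a layer counter, and a weak
// reference used to detect that the caller has freed the matrix.
struct CacheEntry {
  std::vector<float> predictions;
  uint32_t version{0};
  std::weak_ptr<Matrix> ref;
};

// Mirrors PredictionContainer: entries are keyed by the raw pointer, while
// lifetime is tracked through the weak_ptr set from the shared_ptr.
class Container {
  std::unordered_map<Matrix*, CacheEntry> entries_;

  void ClearExpired() {
    for (auto it = entries_.begin(); it != entries_.end();) {
      it = it->second.ref.expired() ? entries_.erase(it) : std::next(it);
    }
  }

 public:
  // Taking a shared_ptr is the crucial part: it ties the entry to the
  // matrix's lifetime, exactly as documented in predictor.h above.
  CacheEntry& Cache(std::shared_ptr<Matrix> m) {
    ClearExpired();
    auto& entry = entries_[m.get()];
    entry.ref = m;
    return entry;
  }

  size_t Size() {
    ClearExpired();
    return entries_.size();
  }
};

int main() {
  Container cache;
  auto m = std::make_shared<Matrix>();
  m->rows = 16;
  cache.Cache(m).version += 1;        // one boosting round committed
  std::cout << cache.Size() << "\n";  // prints 1: the matrix is still alive
  m.reset();                          // the caller frees its DMatrix
  std::cout << cache.Size() << "\n";  // prints 0: entry swept on next access
}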