diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 40136c5c7991..96a48044d0e1 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -37,6 +37,7 @@ class Column { size_t Size() const { return len_; } uint32_t GetGlobalBinIdx(size_t idx) const { return index_base_ + index_[idx]; } uint32_t GetFeatureBinIdx(size_t idx) const { return index_[idx]; } + common::Span GetFeatureBinIdxPtr() const { return { index_, len_ }; } // column.GetFeatureBinIdx(idx) + column.GetBaseIdx(idx) == // column.GetGlobalBinIdx(idx) uint32_t GetBaseIdx() const { return index_base_; } @@ -186,8 +187,8 @@ class ColumnMatrix { std::vector feature_counts_; std::vector type_; - SimpleArray index_; // index_: may store smaller integers; needs padding - SimpleArray row_ind_; + std::vector index_; // index_: may store smaller integers; needs padding + std::vector row_ind_; std::vector boundary_; // index_base_[fid]: least bin id for feature fid diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index c67b2b0e7a45..b57ff5e4bca9 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -672,7 +672,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat, } /*! - * \brief fill a histogram by zeroes + * \brief fill a histogram by zeros in range [begin, end) */ void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) { memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats)); @@ -719,43 +719,144 @@ void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2, } } - -void GHistBuilder::BuildHist(const std::vector& gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - GHistRow hist) { - const size_t* rid = row_indices.begin; - const size_t nrows = row_indices.Size(); - const uint32_t* index = gmat.index.data(); - const size_t* row_ptr = gmat.row_ptr.data(); +struct Prefetch { + public: + static constexpr size_t kCacheLineSize = 64; + static constexpr size_t kPrefetchOffset = 10; + static constexpr size_t kPrefetchStep = + kCacheLineSize / sizeof(decltype(GHistIndexMatrix::index)::value_type); + + private: + static constexpr size_t kNoPrefetchSize = + kPrefetchOffset + kCacheLineSize / + sizeof(decltype(GHistIndexMatrix::row_ptr)::value_type); + + public: + static size_t NoPrefetchSize(size_t rows) { + return std::min(rows, kNoPrefetchSize); + } +}; + +constexpr size_t Prefetch::kNoPrefetchSize; + +template +void BuildHistDenseKernel(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + const size_t n_features, + GHistRow hist) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; const float* pgh = reinterpret_cast(gpair.data()); + const uint32_t* gradient_index = gmat.index.data(); + FPType* hist_data = reinterpret_cast(hist.data()); + + const uint32_t two {2}; // Each element from 'gpair' and 'hist' contains + // 2 FP values: gradient and hessian. 
+ // So we need to multiply each row-index/bin-index by 2 + // to work with gradient pairs as a singe row FP array + + for (size_t i = 0; i < size; ++i) { + const size_t icol_start = rid[i] * n_features; + const size_t idx_gh = two * rid[i]; + + if (do_prefetch) { + const size_t icol_start_prefetch = rid[i + Prefetch::kPrefetchOffset] * n_features; - double* hist_data = reinterpret_cast(hist.data()); + PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]); + for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features; + j += Prefetch::kPrefetchStep) { + PREFETCH_READ_T0(gradient_index + j); + } + } + + for (size_t j = icol_start; j < icol_start + n_features; ++j) { + const uint32_t idx_bin = two * gradient_index[j]; + + hist_data[idx_bin] += pgh[idx_gh]; + hist_data[idx_bin+1] += pgh[idx_gh+1]; + } + } +} + +template +void BuildHistSparseKernel(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + GHistRow hist) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; + const float* pgh = reinterpret_cast(gpair.data()); + const uint32_t* gradient_index = gmat.index.data(); + const size_t* row_ptr = gmat.row_ptr.data(); + FPType* hist_data = reinterpret_cast(hist.data()); - const size_t cache_line_size = 64; - const size_t prefetch_offset = 10; - size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid); - no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size; + const uint32_t two {2}; // Each element from 'gpair' and 'hist' contains + // 2 FP values: gradient and hessian. + // So we need to multiply each row-index/bin-index by 2 + // to work with gradient pairs as a singe row FP array - for (size_t i = 0; i < nrows; ++i) { + for (size_t i = 0; i < size; ++i) { const size_t icol_start = row_ptr[rid[i]]; const size_t icol_end = row_ptr[rid[i]+1]; + const size_t idx_gh = two * rid[i]; + + if (do_prefetch) { + const size_t icol_start_prftch = row_ptr[rid[i+Prefetch::kPrefetchOffset]]; + const size_t icol_end_prefect = row_ptr[rid[i+Prefetch::kPrefetchOffset]+1]; - if (i < nrows - no_prefetch_size) { - PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]); - PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]); + PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]); + for (size_t j = icol_start_prftch; j < icol_end_prefect; j+=Prefetch::kPrefetchStep) { + PREFETCH_READ_T0(gradient_index + j); + } } for (size_t j = icol_start; j < icol_end; ++j) { - const uint32_t idx_bin = 2*index[j]; - const size_t idx_gh = 2*rid[i]; - - hist_data[idx_bin] += pgh[idx_gh]; + const uint32_t idx_bin = two * gradient_index[j]; + hist_data[idx_bin] += pgh[idx_gh]; hist_data[idx_bin+1] += pgh[idx_gh+1]; } } } +template +void BuildHistKernel(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, const bool isDense, GHistRow hist) { + if (row_indices.Size() && isDense) { + const size_t* row_ptr = gmat.row_ptr.data(); + const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]]; + BuildHistDenseKernel(gpair, row_indices, gmat, n_features, hist); + } else { + BuildHistSparseKernel(gpair, row_indices, gmat, hist); + } +} + +void GHistBuilder::BuildHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + GHistRow hist, + bool isDense) { + using FPType = decltype(tree::GradStats::sum_grad); + const size_t nrows = row_indices.Size(); + const size_t 
no_prefetch_size = Prefetch::NoPrefetchSize(nrows); + + // if need to work with all rows from bin-matrix (e.g. root node) + const bool contiguousBlock = (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); + + if (contiguousBlock) { + // contiguous memory access, built-in HW prefetching is enough + BuildHistKernel(gpair, row_indices, gmat, isDense, hist); + } else { + const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size); + const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end); + + BuildHistKernel(gpair, span1, gmat, isDense, hist); + // no prefetching to avoid loading extra memory + BuildHistKernel(gpair, span2, gmat, isDense, hist); + } +} + void GHistBuilder::BuildBlockHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexBlockMatrix& gmatb, diff --git a/src/common/hist_util.h b/src/common/hist_util.h index a1103f2a520b..aa0c57ab4034 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017 by Contributors + * Copyright 2017-2020 by Contributors * \file hist_util.h * \brief Utility for fast histogram aggregation * \author Philip Cho, Tianqi Chen @@ -25,75 +25,6 @@ namespace xgboost { namespace common { - -/* - * \brief A thin wrapper around dynamically allocated C-style array. - * Make sure to call resize() before use. - */ -template -struct SimpleArray { - ~SimpleArray() { - std::free(ptr_); - ptr_ = nullptr; - } - - void resize(size_t n) { - T* ptr = static_cast(std::malloc(n * sizeof(T))); - CHECK(ptr) << "Failed to allocate memory"; - if (ptr_) { - std::memcpy(ptr, ptr_, n_ * sizeof(T)); - std::free(ptr_); - } - ptr_ = ptr; - n_ = n; - } - - T& operator[](size_t idx) { - return ptr_[idx]; - } - - T& operator[](size_t idx) const { - return ptr_[idx]; - } - - size_t size() const { - return n_; - } - - T back() const { - return ptr_[n_-1]; - } - - T* data() { - return ptr_; - } - - const T* data() const { - return ptr_; - } - - - T* begin() { - return ptr_; - } - - const T* begin() const { - return ptr_; - } - - T* end() { - return ptr_ + n_; - } - - const T* end() const { - return ptr_ + n_; - } - - private: - T* ptr_ = nullptr; - size_t n_ = 0; -}; - /*! * \brief A single row in global histogram index. * Directly represent the global index in the histogram entry. @@ -161,7 +92,7 @@ class HistogramCuts { return idx; } - BinIdx SearchBin(Entry const& e) { + BinIdx SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); } }; @@ -261,8 +192,9 @@ size_t DeviceSketch(int device, /*! * \brief preprocessed global index matrix, in CSR format - * Transform floating values to integer index in histogram - * This is a global histogram index. + * + * Transform floating values to integer index in histogram This is a global histogram + * index for CPU histogram. On GPU ellpack page is used. */ struct GHistIndexMatrix { /*! 
\brief row pointer to rows by element position */ @@ -606,17 +538,15 @@ class ParallelGHistBuilder { */ class GHistBuilder { public: - // initialize builder - inline void Init(size_t nthread, uint32_t nbins) { - nthread_ = nthread; - nbins_ = nbins; - } + GHistBuilder() : nthread_{0}, nbins_{0} {} + GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {} // construct a histogram via histogram aggregation void BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - GHistRow hist); + GHistRow hist, + bool isDense); // same, with feature grouping void BuildBlockHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, @@ -625,7 +555,7 @@ class GHistBuilder { // construct a histogram via subtraction trick void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent); - uint32_t GetNumBins() { + uint32_t GetNumBins() const { return nbins_; } diff --git a/src/common/row_set.h b/src/common/row_set.h index 285988b159c3..179e07d5024f 100644 --- a/src/common/row_set.h +++ b/src/common/row_set.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace xgboost { namespace common { @@ -29,7 +30,7 @@ class RowSetCollection { = default; Elem(const size_t* begin, const size_t* end, - int node_id) + int node_id = -1) : begin(begin), end(end), node_id(node_id) {} inline size_t Size() const { @@ -57,6 +58,13 @@ class RowSetCollection { << "access element that is not in the set"; return e; } + + /*! \brief return corresponding element set given the node_id */ + inline Elem& operator[](unsigned node_id) { + Elem& e = elem_of_each_node_[node_id]; + return e; + } + // clear up things inline void Clear() { elem_of_each_node_.clear(); @@ -83,25 +91,18 @@ class RowSetCollection { } // split rowset into two inline void AddSplit(unsigned node_id, - const std::vector& row_split_tloc, unsigned left_node_id, - unsigned right_node_id) { + unsigned right_node_id, + size_t n_left, + size_t n_right) { const Elem e = elem_of_each_node_[node_id]; - const auto nthread = static_cast(row_split_tloc.size()); CHECK(e.begin != nullptr); size_t* all_begin = dmlc::BeginPtr(row_indices_); size_t* begin = all_begin + (e.begin - all_begin); - size_t* it = begin; - for (bst_omp_uint tid = 0; tid < nthread; ++tid) { - std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it); - it += row_split_tloc[tid].left.size(); - } - size_t* split_pt = it; - for (bst_omp_uint tid = 0; tid < nthread; ++tid) { - std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it); - it += row_split_tloc[tid].right.size(); - } + CHECK_EQ(n_left + n_right, e.Size()); + CHECK_LE(begin + n_left, e.end); + CHECK_EQ(begin + n_left + n_right, e.end); if (left_node_id >= elem_of_each_node_.size()) { elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1)); @@ -110,12 +111,12 @@ class RowSetCollection { elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1)); } - elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id); - elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id); + elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id); + elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id); elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1); } - // stores the row indices in the set + // stores the row indexes in the set std::vector row_indices_; private: @@ -123,6 +124,121 @@ class RowSetCollection { 
std::vector elem_of_each_node_; }; + +// The builder is required for samples partition to left and rights children for set of nodes +// Responsible for: +// 1) Effective memory allocation for intermediate results for multi-thread work +// 2) Merging partial results produced by threads into original row set (row_set_collection_) +// BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature +template +class PartitionBuilder { + public: + template + void Init(const size_t n_tasks, size_t n_nodes, Func funcNTaks) { + left_right_nodes_sizes_.resize(n_nodes); + blocks_offsets_.resize(n_nodes+1); + + blocks_offsets_[0] = 0; + for (size_t i = 1; i < n_nodes+1; ++i) { + blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTaks(i-1); + } + + if (n_tasks > max_n_tasks_) { + mem_blocks_.resize(n_tasks); + max_n_tasks_ = n_tasks; + } + } + + common::Span GetLeftBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + return { mem_blocks_.at(task_idx).left(), end - begin }; + } + + common::Span GetRightBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + return { mem_blocks_.at(task_idx).right(), end - begin }; + } + + void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) { + size_t task_idx = GetTaskIdx(nid, begin); + mem_blocks_.at(task_idx).n_left = n_left; + } + + void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) { + size_t task_idx = GetTaskIdx(nid, begin); + mem_blocks_.at(task_idx).n_right = n_right; + } + + + size_t GetNLeftElems(int nid) const { + return left_right_nodes_sizes_[nid].first; + } + + size_t GetNRightElems(int nid) const { + return left_right_nodes_sizes_[nid].second; + } + + // Each thread has partial results for some set of tree-nodes + // The function decides order of merging partial results into final row set + void CalculateRowOffsets() { + for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) { + size_t n_left = 0; + for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { + mem_blocks_[j].n_offset_left = n_left; + n_left += mem_blocks_[j].n_left; + } + size_t n_right = 0; + for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { + mem_blocks_[j].n_offset_right = n_left + n_right; + n_right += mem_blocks_[j].n_right; + } + left_right_nodes_sizes_[i] = {n_left, n_right}; + } + } + + void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { + size_t task_idx = GetTaskIdx(nid, begin); + + size_t* left_result = rows_indexes + mem_blocks_[task_idx].n_offset_left; + size_t* right_result = rows_indexes + mem_blocks_[task_idx].n_offset_right; + + const size_t* left = mem_blocks_[task_idx].left(); + const size_t* right = mem_blocks_[task_idx].right(); + + std::copy_n(left, mem_blocks_[task_idx].n_left, left_result); + std::copy_n(right, mem_blocks_[task_idx].n_right, right_result); + } + + protected: + size_t GetTaskIdx(int nid, size_t begin) { + return blocks_offsets_[nid] + begin / BlockSize; + } + + struct BlockInfo{ + size_t n_left; + size_t n_right; + + size_t n_offset_left; + size_t n_offset_right; + + size_t* left() { + return &left_data_[0]; + } + + size_t* right() { + return &right_data_[0]; + } + private: + alignas(128) size_t left_data_[BlockSize]; + alignas(128) size_t right_data_[BlockSize]; + }; + std::vector> left_right_nodes_sizes_; + std::vector blocks_offsets_; + std::vector mem_blocks_; + size_t max_n_tasks_ = 0; +}; + + } // namespace common } // namespace xgboost diff --git 
a/src/common/threading_utils.h b/src/common/threading_utils.h index ff7c4667636c..3f32173b0e21 100755 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -9,6 +9,8 @@ #include #include +#include "xgboost/logging.h" + namespace xgboost { namespace common { @@ -20,11 +22,11 @@ class Range1d { CHECK_LT(begin, end); } - size_t begin() { + size_t begin() const { // NOLINT return begin_; } - size_t end() { + size_t end() const { // NOLINT return end_; } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 11a7bdaf2123..99689ee8df9e 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -239,17 +239,14 @@ void QuantileHistMaker::Builder::BuildNodeStats( builder_monitor_.Stop("BuildNodeStats"); } -void QuantileHistMaker::Builder::EvaluateSplits( - const GHistIndexMatrix &gmat, - const ColumnMatrix &column_matrix, - DMatrix *p_fmat, - RegTree *p_tree, - int *num_leaves, - int depth, - unsigned *timestamp, - std::vector *temp_qexpand_depth) { - EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree); - +void QuantileHistMaker::Builder::AddSplitsToTree( + const GHistIndexMatrix &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth) { for (auto const& entry : qexpand_depth_wise_) { int nid = entry.nid; @@ -258,7 +255,17 @@ void QuantileHistMaker::Builder::EvaluateSplits( (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) { (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); } else { - this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree); + nodes_for_apply_split->push_back(entry); + + NodeEntry& e = snode_[nid]; + bst_float left_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; + bst_float right_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); + int left_id = (*p_tree)[nid].LeftChild(); int right_id = (*p_tree)[nid].RightChild(); temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id, @@ -271,6 +278,24 @@ void QuantileHistMaker::Builder::EvaluateSplits( } } + +void QuantileHistMaker::Builder::EvaluateAndApplySplits( + const GHistIndexMatrix &gmat, + const ColumnMatrix &column_matrix, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector *temp_qexpand_depth) { + EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree); + + std::vector nodes_for_apply_split; + AddSplitsToTree(gmat, p_tree, num_leaves, depth, timestamp, + &nodes_for_apply_split, temp_qexpand_depth); + + ApplySplit(nodes_for_apply_split, gmat, column_matrix, hist_, p_tree); +} + // Split nodes to 2 sets depending on amount of rows in each node // Histograms for small nodes will be built explicitly // Histograms for big nodes will be built by 'Subtraction Trick' @@ -335,7 +360,7 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise( SyncHistograms(starting_index, sync_count, p_tree); BuildNodeStats(gmat, p_fmat, p_tree, gpair_h); - EvaluateSplits(gmat, column_matrix, p_fmat, p_tree, &num_leaves, depth, ×tamp, + EvaluateAndApplySplits(gmat, column_matrix, p_tree, &num_leaves, depth, ×tamp, &temp_qexpand_depth); // clean up qexpand_depth_wise_.clear(); @@ -367,7 +392,7 @@ void 
QuantileHistMaker::Builder::ExpandWithLossGuide( this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree); - this->EvaluateSplit({node}, gmat, hist_, *p_fmat, *p_tree); + this->EvaluateSplits({node}, gmat, hist_, *p_tree); node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg; qexpand_loss_guided_->push(node); @@ -377,12 +402,19 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( const ExpandEntry candidate = qexpand_loss_guided_->top(); const int nid = candidate.nid; qexpand_loss_guided_->pop(); - if (candidate.loss_chg <= kRtEps - || (param_.max_depth > 0 && candidate.depth == param_.max_depth) - || (param_.max_leaves > 0 && num_leaves == param_.max_leaves) ) { + if (candidate.IsValid(param_, num_leaves)) { (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); } else { - this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree); + NodeEntry& e = snode_[nid]; + bst_float left_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; + bst_float right_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); + + this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree); const int cleft = (*p_tree)[nid].LeftChild(); const int cright = (*p_tree)[nid].RightChild(); @@ -410,7 +442,7 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( snode_[cleft].weight, snode_[cright].weight); interaction_constraints_.Split(nid, featureid, cleft, cright); - this->EvaluateSplit({left_node, right_node}, gmat, hist_, *p_fmat, *p_tree); + this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree); left_node.loss_chg = snode_[cleft].best.loss_chg; right_node.loss_chg = snode_[cright].best.loss_chg; @@ -473,7 +505,14 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( CHECK_GT(out_preds.size(), 0U); - for (const RowSetCollection::Elem rowset : row_set_collection_) { + size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); + + common::BlockedSpace2d space(n_nodes, [&](size_t node) { + return row_set_collection_[node].Size(); + }, 1024); + + common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) { + const RowSetCollection::Elem rowset = row_set_collection_[node]; if (rowset.begin != nullptr && rowset.end != nullptr) { int nid = rowset.node_id; bst_float leaf_value; @@ -487,11 +526,11 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } leaf_value = (*p_last_tree_)[nid].LeafValue(); - for (const size_t* it = rowset.begin; it < rowset.end; ++it) { + for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { out_preds[*it] += leaf_value; } } - } + }); builder_monitor_.Stop("UpdatePredictionCache"); return true; @@ -526,7 +565,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, { this->nthread_ = omp_get_num_threads(); } - hist_builder_.Init(this->nthread_, nbins); + hist_builder_ = GHistBuilder(this->nthread_, nbins); std::vector& row_indices = row_set_collection_.row_indices_; row_indices.resize(info.num_row_); @@ -674,12 +713,11 @@ bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e, } // nodes_set - set of nodes to be processed in parallel -void QuantileHistMaker::Builder::EvaluateSplit(const std::vector& nodes_set, +void QuantileHistMaker::Builder::EvaluateSplits(const 
std::vector& nodes_set, const GHistIndexMatrix& gmat, const HistCollection& hist, - const DMatrix& fmat, const RegTree& tree) { - builder_monitor_.Start("EvaluateSplit"); + builder_monitor_.Start("EvaluateSplits"); const size_t n_nodes_in_set = nodes_set.size(); const size_t nthread = std::max(1, this->nthread_); @@ -714,11 +752,11 @@ void QuantileHistMaker::Builder::EvaluateSplit(const std::vector& n for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end(); ++idx_in_feature_set) { const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set]; if (interaction_constraints_.Query(nid, fid)) { - auto grad_stats = this->EnumerateSplit<+1>(gmat, node_hist, snode_[nid], fmat.Info(), - &best_split_tloc_[nthread*nid_in_set + tid], fid, nid); + auto grad_stats = this->EnumerateSplit<+1>(gmat, node_hist, snode_[nid], + &best_split_tloc_[nthread*nid_in_set + tid], fid, nid); if (SplitContainsMissingValues(grad_stats, snode_[nid])) { - this->EnumerateSplit<-1>(gmat, node_hist, snode_[nid], fmat.Info(), - &best_split_tloc_[nthread*nid_in_set + tid], fid, nid); + this->EnumerateSplit<-1>(gmat, node_hist, snode_[nid], + &best_split_tloc_[nthread*nid_in_set + tid], fid, nid); } } } @@ -732,198 +770,240 @@ void QuantileHistMaker::Builder::EvaluateSplit(const std::vector& n } } - builder_monitor_.Stop("EvaluateSplit"); + builder_monitor_.Stop("EvaluateSplits"); } -void QuantileHistMaker::Builder::ApplySplit(int nid, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - const HistCollection& hist, - const DMatrix& fmat, - RegTree* p_tree) { - builder_monitor_.Start("ApplySplit"); - // TODO(hcho3): support feature sampling by levels - - /* 1. Create child nodes */ - NodeEntry& e = snode_[nid]; - bst_float left_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; - bst_float right_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; - p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, - e.best.DefaultLeft(), e.weight, left_leaf_weight, - right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); - - /* 2. 
Categorize member rows */ - const auto nthread = static_cast(this->nthread_); - row_split_tloc_.resize(nthread); - for (bst_omp_uint i = 0; i < nthread; ++i) { - row_split_tloc_[i].left.clear(); - row_split_tloc_[i].right.clear(); - } - const bool default_left = (*p_tree)[nid].DefaultLeft(); - const bst_uint fid = (*p_tree)[nid].SplitIndex(); - const bst_float split_pt = (*p_tree)[nid].SplitCond(); - const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; - const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; - int32_t split_cond = -1; - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - CHECK_LT(upper_bound, - static_cast(std::numeric_limits::max())); - for (uint32_t i = lower_bound; i < upper_bound; ++i) { - if (split_pt == gmat.cut.Values()[i]) { - split_cond = static_cast(i); +// split row indexes (rid_span) to 2 parts (left_part, right_part) depending +// on comparison of indexes values (idx_span) and split point (split_cond) +// Handle dense columns +// Analog of std::stable_partition, but in no-inplace manner +template +inline std::pair PartitionDenseKernel( + common::Span rid_span, common::Span idx_span, + const int32_t split_cond, const uint32_t offset, + common::Span left_part, common::Span right_part) { + const uint32_t* idx = idx_span.data(); + size_t* p_left_part = left_part.data(); + size_t* p_right_part = right_part.data(); + size_t nleft_elems = 0; + size_t nright_elems = 0; + + const uint32_t missing_val = std::numeric_limits::max(); + + for (auto rid : rid_span) { + if (idx[rid] == missing_val) { + if (default_left) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } + } else { + if (static_cast(idx[rid] + offset) <= split_cond) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } } } - const auto& rowset = row_set_collection_[nid]; - - Column column = column_matrix.GetColumn(fid); - if (column.GetType() == xgboost::common::kDenseColumn) { - ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column, split_cond, - default_left); - } else { - ApplySplitSparseData(rowset, gmat, &row_split_tloc_, column, lower_bound, - upper_bound, split_cond, default_left); - } - - row_set_collection_.AddSplit( - nid, row_split_tloc_, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild()); - builder_monitor_.Stop("ApplySplit"); + return {nleft_elems, nright_elems}; } -void QuantileHistMaker::Builder::ApplySplitDenseData( - const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_int split_cond, - bool default_left) { - std::vector& row_split_tloc = *p_row_split_tloc; - constexpr int kUnroll = 8; // loop unrolling factor - const size_t nrows = rowset.end - rowset.begin; - const size_t rest = nrows % kUnroll; - -#pragma omp parallel for num_threads(nthread_) schedule(static) - for (bst_omp_uint i = 0; i < nrows - rest; i += kUnroll) { - const bst_uint tid = omp_get_thread_num(); - auto& left = row_split_tloc[tid].left; - auto& right = row_split_tloc[tid].right; - size_t rid[kUnroll]; - uint32_t rbin[kUnroll]; - for (int k = 0; k < kUnroll; ++k) { - rid[k] = rowset.begin[i + k]; - } - for (int k = 0; k < kUnroll; ++k) { - rbin[k] = column.GetFeatureBinIdx(rid[k]); - } - for (int k = 0; k < kUnroll; ++k) { // NOLINT - if (rbin[k] == std::numeric_limits::max()) { // missing value - if (default_left) { - left.push_back(rid[k]); - } else { - 
right.push_back(rid[k]); +// Split row indexes (rid_span) to 2 parts (left_part, right_part) depending +// on comparison of indexes values (idx_span) and split point (split_cond). +// Handle sparse columns +template +inline std::pair PartitionSparseKernel( + common::Span rid_span, const int32_t split_cond, const Column& column, + common::Span left_part, common::Span right_part) { + size_t* p_left_part = left_part.data(); + size_t* p_right_part = right_part.data(); + + size_t nleft_elems = 0; + size_t nright_elems = 0; + + if (rid_span.size()) { // ensure that rid_span is nonempty range + // search first nonzero row with index >= rid_span.front() + const size_t* p = std::lower_bound(column.GetRowData(), + column.GetRowData() + column.Size(), + rid_span.front()); + + if (p != column.GetRowData() + column.Size() && *p <= rid_span.back()) { + size_t cursor = p - column.GetRowData(); + + for (auto rid : rid_span) { + while (cursor < column.Size() + && column.GetRowIdx(cursor) < rid + && column.GetRowIdx(cursor) <= rid_span.back()) { + ++cursor; } - } else { - if (static_cast(rbin[k] + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid[k]); + if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) { + const uint32_t rbin = column.GetFeatureBinIdx(cursor); + if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } + ++cursor; } else { - right.push_back(rid[k]); + // missing value + if (default_left) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } } } - } - } - for (size_t i = nrows - rest; i < nrows; ++i) { - auto& left = row_split_tloc[nthread_-1].left; - auto& right = row_split_tloc[nthread_-1].right; - const size_t rid = rowset.begin[i]; - const uint32_t rbin = column.GetFeatureBinIdx(rid); - if (rbin == std::numeric_limits::max()) { // missing value + } else { // all rows in rid_span have missing values if (default_left) { - left.push_back(rid); + std::copy(rid_span.begin(), rid_span.end(), p_left_part); + nleft_elems = rid_span.size(); } else { - right.push_back(rid); + std::copy(rid_span.begin(), rid_span.end(), p_right_part); + nright_elems = rid_span.size(); } + } + } + + return {nleft_elems, nright_elems}; +} + +void QuantileHistMaker::Builder::PartitionKernel( + const size_t node_in_set, const size_t nid, common::Range1d range, + const int32_t split_cond, const ColumnMatrix& column_matrix, + const GHistIndexMatrix& gmat, const RegTree& tree) { + const size_t* rid = row_set_collection_[nid].begin; + common::Span rid_span(rid + range.begin(), rid + range.end()); + common::Span left = partition_builder_.GetLeftBuffer(node_in_set, + range.begin(), range.end()); + common::Span right = partition_builder_.GetRightBuffer(node_in_set, + range.begin(), range.end()); + const bst_uint fid = tree[nid].SplitIndex(); + const bool default_left = tree[nid].DefaultLeft(); + const auto column = column_matrix.GetColumn(fid); + const uint32_t offset = column.GetBaseIdx(); + common::Span idx_spin = column.GetFeatureBinIdxPtr(); + + std::pair child_nodes_sizes; + + if (column.GetType() == xgboost::common::kDenseColumn) { + if (default_left) { + child_nodes_sizes = PartitionDenseKernel( + rid_span, idx_spin, split_cond, offset, left, right); } else { - if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid); - } else { - right.push_back(rid); - } + child_nodes_sizes = PartitionDenseKernel( + rid_span, idx_spin, split_cond, offset, 
left, right); + } + } else { + if (default_left) { + child_nodes_sizes = PartitionSparseKernel(rid_span, split_cond, column, left, right); + } else { + child_nodes_sizes = PartitionSparseKernel(rid_span, split_cond, column, left, right); } } + + const size_t n_left = child_nodes_sizes.first; + const size_t n_right = child_nodes_sizes.second; + + partition_builder_.SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); + partition_builder_.SetNRightElems(node_in_set, range.begin(), range.end(), n_right); } -void QuantileHistMaker::Builder::ApplySplitSparseData( - const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_uint lower_bound, - bst_uint upper_bound, - bst_int split_cond, - bool default_left) { - std::vector& row_split_tloc = *p_row_split_tloc; - const size_t nrows = rowset.end - rowset.begin; - -#pragma omp parallel num_threads(nthread_) - { - const auto tid = static_cast(omp_get_thread_num()); - const size_t ibegin = tid * nrows / nthread_; - const size_t iend = (tid + 1) * nrows / nthread_; - if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range - // search first nonzero row with index >= rowset[ibegin] - const size_t* p = std::lower_bound(column.GetRowData(), - column.GetRowData() + column.Size(), - rowset.begin[ibegin]); - - auto& left = row_split_tloc[tid].left; - auto& right = row_split_tloc[tid].right; - if (p != column.GetRowData() + column.Size() && *p <= rowset.begin[iend - 1]) { - size_t cursor = p - column.GetRowData(); - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = rowset.begin[i]; - while (cursor < column.Size() - && column.GetRowIdx(cursor) < rid - && column.GetRowIdx(cursor) <= rowset.begin[iend - 1]) { - ++cursor; - } - if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) { - const uint32_t rbin = column.GetFeatureBinIdx(cursor); - if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid); - } else { - right.push_back(rid); - } - ++cursor; - } else { - // missing value - if (default_left) { - left.push_back(rid); - } else { - right.push_back(rid); - } - } - } - } else { // all rows in [ibegin, iend) have missing values - if (default_left) { - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = rowset.begin[i]; - left.push_back(rid); - } - } else { - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = rowset.begin[i]; - right.push_back(rid); - } - } +void QuantileHistMaker::Builder::FindSplitConditions(const std::vector& nodes, + const RegTree& tree, + const GHistIndexMatrix& gmat, + std::vector* split_conditions) { + const size_t n_nodes = nodes.size(); + split_conditions->resize(n_nodes); + + for (size_t i = 0; i < nodes.size(); ++i) { + const int32_t nid = nodes[i].nid; + const bst_uint fid = tree[nid].SplitIndex(); + const bst_float split_pt = tree[nid].SplitCond(); + const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; + const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; + int32_t split_cond = -1; + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + CHECK_LT(upper_bound, + static_cast(std::numeric_limits::max())); + for (uint32_t i = lower_bound; i < upper_bound; ++i) { + if (split_pt == gmat.cut.Values()[i]) { + split_cond = static_cast(i); } } + (*split_conditions)[i] = split_cond; + } +} + +void QuantileHistMaker::Builder::AddSplitsToRowSet(const std::vector& nodes, + RegTree* p_tree) { + const 
size_t n_nodes = nodes.size(); + for (size_t i = 0; i < n_nodes; ++i) { + const int32_t nid = nodes[i].nid; + const size_t n_left = partition_builder_.GetNLeftElems(i); + const size_t n_right = partition_builder_.GetNRightElems(i); + + row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), + (*p_tree)[nid].RightChild(), n_left, n_right); } } + +void QuantileHistMaker::Builder::ApplySplit(const std::vector nodes, + const GHistIndexMatrix& gmat, + const ColumnMatrix& column_matrix, + const HistCollection& hist, + RegTree* p_tree) { + builder_monitor_.Start("ApplySplit"); + + // 1. Find split condition for each split + const size_t n_nodes = nodes.size(); + std::vector split_conditions; + FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); + + // 2.1 Create a blocked space of size SUM(samples in each node) + common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) { + int32_t nid = nodes[node_in_set].nid; + return row_set_collection_[nid].Size(); + }, kPartitionBlockSize); + + // 2.2 Initialize the partition builder + // allocate buffers for storage intermediate results by each thread + partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) { + const int32_t nid = nodes[node_in_set].nid; + const size_t size = row_set_collection_[nid].Size(); + const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize); + return n_tasks; + }); + + // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node + // Store results in intermediate buffers from partition_builder_ + common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) { + const int32_t nid = nodes[node_in_set].nid; + PartitionKernel(node_in_set, nid, r, + split_conditions[node_in_set], column_matrix, gmat, *p_tree); + }); + + // 3. Compute offsets to copy blocks of row-indexes + // from partition_builder_ to row_set_collection_ + partition_builder_.CalculateRowOffsets(); + + // 4. Copy elements from partition_builder_ to row_set_collection_ back + // with updated row-indexes for each tree-node + common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) { + const int32_t nid = nodes[node_in_set].nid; + partition_builder_.MergeToArray(node_in_set, r.begin(), + const_cast(row_set_collection_[nid].begin)); + }); + + // 5. Add info about splits into row_set_collection_ + AddSplitsToRowSet(nodes, p_tree); + + builder_monitor_.Stop("ApplySplit"); +} + void QuantileHistMaker::Builder::InitNewNode(int nid, const GHistIndexMatrix& gmat, const std::vector& gpair, @@ -979,15 +1059,10 @@ void QuantileHistMaker::Builder::InitNewNode(int nid, // Enumerate the split values of specific feature. // Returns the sum of gradients corresponding to the data points that contains a non-missing value // for the particular feature fid. 
-template +template GradStats QuantileHistMaker::Builder::EnumerateSplit( - const GHistIndexMatrix& gmat, - const GHistRow& hist, - const NodeEntry& snode, - const MetaInfo& info, - SplitEntry* p_best, - bst_uint fid, - bst_uint nodeID) { + const GHistIndexMatrix &gmat, const GHistRow &hist, const NodeEntry &snode, + SplitEntry *p_best, bst_uint fid, bst_uint nodeID) const { CHECK(d_step == +1 || d_step == -1); // aliases diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 18dd4ef1baa7..bef69a226b70 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -161,7 +161,7 @@ class QuantileHistMaker: public TreeUpdater { if (param_.enable_feature_grouping > 0) { hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist); } else { - hist_builder_.BuildHist(gpair, row_indices, gmat, hist); + hist_builder_.BuildHist(gpair, row_indices, gmat, hist, data_layout_ != kSparseData); } } @@ -186,6 +186,13 @@ class QuantileHistMaker: public TreeUpdater { unsigned timestamp; ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg, unsigned tstmp): nid(nid), sibling_nid(sibling_nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {} + + bool IsValid(TrainParam const& param, int32_t num_leaves) const { + bool ret = loss_chg <= kRtEps || + (param.max_depth > 0 && this->depth == param.max_depth) || + (param.max_leaves > 0 && num_leaves == param.max_leaves); + return ret; + } }; // initialize temp data structure @@ -194,34 +201,27 @@ class QuantileHistMaker: public TreeUpdater { const DMatrix& fmat, const RegTree& tree); - void EvaluateSplit(const std::vector& nodes_set, - const GHistIndexMatrix& gmat, - const HistCollection& hist, - const DMatrix& fmat, - const RegTree& tree); - - void ApplySplit(int nid, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - const HistCollection& hist, - const DMatrix& fmat, - RegTree* p_tree); - - void ApplySplitDenseData(const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_int split_cond, - bool default_left); - - void ApplySplitSparseData(const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_uint lower_bound, - bst_uint upper_bound, - bst_int split_cond, - bool default_left); + void EvaluateSplits(const std::vector& nodes_set, + const GHistIndexMatrix& gmat, + const HistCollection& hist, + const RegTree& tree); + + void ApplySplit(std::vector nodes, + const GHistIndexMatrix& gmat, + const ColumnMatrix& column_matrix, + const HistCollection& hist, + RegTree* p_tree); + + void PartitionKernel(const size_t node_in_set, const size_t nid, common::Range1d range, + const int32_t split_cond, + const ColumnMatrix& column_matrix, const GHistIndexMatrix& gmat, + const RegTree& tree); + + void AddSplitsToRowSet(const std::vector& nodes, RegTree* p_tree); + + + void FindSplitConditions(const std::vector& nodes, const RegTree& tree, + const GHistIndexMatrix& gmat, std::vector* split_conditions); void InitNewNode(int nid, const GHistIndexMatrix& gmat, @@ -232,15 +232,10 @@ class QuantileHistMaker: public TreeUpdater { // Enumerate the split values of specific feature // Returns the sum of gradients corresponding to the data points that contains a non-missing // value for the particular feature fid. 
- template - GradStats EnumerateSplit( - const GHistIndexMatrix& gmat, - const GHistRow& hist, - const NodeEntry& snode, - const MetaInfo& info, - SplitEntry* p_best, - bst_uint fid, - bst_uint nodeID); + template + GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRow &hist, + const NodeEntry &snode, SplitEntry *p_best, + bst_uint fid, bst_uint nodeID) const; // if sum of statistics for non-missing values in the node // is equal to sum of statistics for all values: @@ -286,14 +281,22 @@ class QuantileHistMaker: public TreeUpdater { RegTree *p_tree, const std::vector &gpair_h); - void EvaluateSplits(const GHistIndexMatrix &gmat, - const ColumnMatrix &column_matrix, - DMatrix *p_fmat, - RegTree *p_tree, - int *num_leaves, - int depth, - unsigned *timestamp, - std::vector *temp_qexpand_depth); + void EvaluateAndApplySplits(const GHistIndexMatrix &gmat, + const ColumnMatrix &column_matrix, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector *temp_qexpand_depth); + + void AddSplitsToTree( + const GHistIndexMatrix &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth); void ExpandWithLossGuide(const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, @@ -335,6 +338,9 @@ class QuantileHistMaker: public TreeUpdater { std::unique_ptr spliteval_; FeatureInteractionConstraintHost interaction_constraints_; + static constexpr size_t kPartitionBlockSize = 2048; + common::PartitionBuilder partition_builder_; + // back pointers to tree and data matrix const RegTree* p_last_tree_; DMatrix const* const p_last_fmat_; diff --git a/tests/cpp/common/test_partition_builder.cc b/tests/cpp/common/test_partition_builder.cc new file mode 100755 index 000000000000..6b51f2eaadc5 --- /dev/null +++ b/tests/cpp/common/test_partition_builder.cc @@ -0,0 +1,76 @@ +#include +#include +#include +#include + +#include "../../../src/common/row_set.h" +#include "../helpers.h" + +namespace xgboost { +namespace common { + +TEST(PartitionBuilder, BasicTest) { + constexpr size_t kBlockSize = 16; + constexpr size_t kNodes = 5; + constexpr size_t kTasks = 3 + 5 + 10 + 1 + 2; + + std::vector tasks = { 3, 5, 10, 1, 2 }; + + PartitionBuilder builder; + builder.Init(kTasks, kNodes, [&](size_t i) { + return tasks[i]; + }); + + std::vector rows_for_left_node = { 2, 12, 0, 16, 8 }; + + for(size_t nid = 0; nid < kNodes; ++nid) { + size_t value_left = 0; + size_t value_right = 0; + + size_t left_total = tasks[nid] * rows_for_left_node[nid]; + + for(size_t j = 0; j < tasks[nid]; ++j) { + size_t begin = kBlockSize*j; + size_t end = kBlockSize*(j+1); + + auto left = builder.GetLeftBuffer(nid, begin, end); + auto right = builder.GetRightBuffer(nid, begin, end); + + size_t n_left = rows_for_left_node[nid]; + size_t n_right = kBlockSize - rows_for_left_node[nid]; + + for(size_t i = 0; i < n_left; i++) { + left[i] = value_left++; + } + + for(size_t i = 0; i < n_right; i++) { + right[i] = left_total + value_right++; + } + + builder.SetNLeftElems(nid, begin, end, n_left); + builder.SetNRightElems(nid, begin, end, n_right); + } + } + builder.CalculateRowOffsets(); + + std::vector v(*std::max_element(tasks.begin(), tasks.end()) * kBlockSize); + + for(size_t nid = 0; nid < kNodes; ++nid) { + + for(size_t j = 0; j < tasks[nid]; ++j) { + builder.MergeToArray(nid, kBlockSize*j, v.data()); + } + + for(size_t j = 0; j < tasks[nid] * kBlockSize; ++j) { + ASSERT_EQ(v[j], j); + } + size_t n_left = 
builder.GetNLeftElems(nid); + size_t n_right = builder.GetNRightElems(nid); + + ASSERT_EQ(n_left, rows_for_left_node[nid] * tasks[nid]); + ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]); + } +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 657f067997a2..1149025e5a75 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -213,7 +213,7 @@ class QuantileHistMock : public QuantileHistMaker { /* Now compare against result given by EvaluateSplit() */ ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid, tree.GetDepth(0), snode_[0].best.loss_chg, 0); - RealImpl::EvaluateSplit({node}, gmat, hist_, *(*dmat), tree); + RealImpl::EvaluateSplits({node}, gmat, hist_, tree); ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature); ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
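
The histogram build above splits each row set into a prefetched head and a short tail (`Prefetch::NoPrefetchSize`) so the kernel can issue software prefetches `kPrefetchOffset` rows ahead without reading past the index arrays. Below is a minimal standalone sketch of the dense kernel, assuming a GCC/Clang `__builtin_prefetch` in place of the `PREFETCH_READ_T0` macro and an explicit bounds check instead of the head/tail split; `BuildHistDenseSketch` and its argument layout are hypothetical illustrations, not the patch's exact code.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t kPrefetchOffset = 10;   // rows ahead to prefetch
constexpr std::size_t kCacheLineSize  = 64;   // bytes
constexpr std::size_t kPrefetchStep   = kCacheLineSize / sizeof(std::uint32_t);

// gradient_index: row-major bin ids, n_features entries per row (dense layout)
// pgh:  interleaved gradient/hessian pairs, 2 floats per row
// hist: interleaved gradient/hessian sums, 2 doubles per bin
template <bool do_prefetch>
void BuildHistDenseSketch(const std::vector<std::size_t>& rows,
                          const std::vector<std::uint32_t>& gradient_index,
                          const std::vector<float>& pgh,
                          std::size_t n_features,
                          std::vector<double>* hist) {
  double* hist_data = hist->data();
  const std::size_t size = rows.size();
  for (std::size_t i = 0; i < size; ++i) {
    const std::size_t icol_start = rows[i] * n_features;
    const std::size_t idx_gh = 2 * rows[i];

    if (do_prefetch && i + kPrefetchOffset < size) {
      // pull the gradient pair and the bin ids of a future row into cache
      const std::size_t prefetch_start = rows[i + kPrefetchOffset] * n_features;
      __builtin_prefetch(&pgh[2 * rows[i + kPrefetchOffset]]);
      for (std::size_t j = prefetch_start; j < prefetch_start + n_features;
           j += kPrefetchStep) {
        __builtin_prefetch(&gradient_index[j]);
      }
    }

    for (std::size_t j = icol_start; j < icol_start + n_features; ++j) {
      const std::uint32_t idx_bin = 2 * gradient_index[j];
      hist_data[idx_bin]     += pgh[idx_gh];      // gradient
      hist_data[idx_bin + 1] += pgh[idx_gh + 1];  // hessian
    }
  }
}
```

In the patch itself the per-iteration bounds check is avoided: `GHistBuilder::BuildHist` runs the prefetching instantiation on the leading rows and the non-prefetching one on the last `NoPrefetchSize(nrows)` rows, and skips prefetching entirely when the row ids form a contiguous block (the root node), where hardware prefetching already suffices.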
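
`FindSplitConditions` keeps the existing logic for mapping a floating-point split value back to a histogram bin; it is simply hoisted out of `ApplySplit` and run once per node up front. A small sketch of that mapping, assuming `cut_values`, `lower_bound`, and `upper_bound` mirror `gmat.cut.Values()`, `gmat.cut.Ptrs()[fid]`, and `gmat.cut.Ptrs()[fid + 1]`; `SplitValueToBin` is a hypothetical helper name.

```cpp
#include <cstdint>
#include <vector>

// Returns the bin id whose cut value equals split_pt within the feature's
// [lower_bound, upper_bound) slice of the cut array; -1 means the split point
// is below every known cut point for this feature.
int32_t SplitValueToBin(float split_pt, uint32_t lower_bound, uint32_t upper_bound,
                        const std::vector<float>& cut_values) {
  int32_t split_cond = -1;
  for (uint32_t i = lower_bound; i < upper_bound; ++i) {
    if (split_pt == cut_values[i]) {
      split_cond = static_cast<int32_t>(i);
    }
  }
  return split_cond;
}
```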
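
`ApplySplit` now partitions row indices out of place: `PartitionKernel` copies each row id into a per-task left or right buffer instead of pushing into per-thread vectors. A minimal sketch of the dense case, assuming the column stores one bin index per row and marks missing values with the maximum `uint32_t` (as the `kDenseColumn` path above does); `PartitionDenseSketch` is a hypothetical name, the buffers are raw pointers sized by the caller, and `default_left` is a runtime flag here rather than a template parameter.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Out-of-place stable partition of rid: bin index <= split_cond goes left,
// larger bins go right, and missing values follow default_left.
std::pair<std::size_t, std::size_t> PartitionDenseSketch(
    const std::vector<std::size_t>& rid, const std::vector<uint32_t>& bins,
    int32_t split_cond, uint32_t offset, bool default_left,
    std::size_t* left, std::size_t* right) {
  constexpr uint32_t kMissing = std::numeric_limits<uint32_t>::max();
  std::size_t n_left = 0, n_right = 0;
  for (std::size_t r : rid) {
    const bool go_left = (bins[r] == kMissing)
        ? default_left
        : static_cast<int32_t>(bins[r] + offset) <= split_cond;
    if (go_left) {
      left[n_left++] = r;
    } else {
      right[n_right++] = r;
    }
  }
  return {n_left, n_right};
}
```

The patch templates the kernel on `default_left` so the branch is resolved at compile time, and drives it over fixed-size blocks (`kPartitionBlockSize` = 2048) through `ParallelFor2d`.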
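
`PartitionBuilder::CalculateRowOffsets` and `MergeToArray` exist so the per-task buffers can be copied back into `row_set_collection_` with every node's left rows placed before its right rows. A compressed sketch of that bookkeeping for a single node, assuming plain `std::vector` blocks in place of the aligned fixed-size `BlockInfo` storage; `MergeBlocks` is a hypothetical helper.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

struct Block {
  std::vector<std::size_t> left, right;     // rows this task sent left/right
  std::size_t offset_left = 0, offset_right = 0;  // destination offsets
};

// blocks: all tasks of one node, in task order
// rows:   that node's contiguous slice of the global row-index array
std::pair<std::size_t, std::size_t> MergeBlocks(std::vector<Block>* blocks,
                                                std::size_t* rows) {
  std::size_t n_left = 0, n_right = 0;
  // prefix sums: left blocks are laid out first, right blocks after all lefts
  for (Block& b : *blocks) { b.offset_left = n_left; n_left += b.left.size(); }
  for (Block& b : *blocks) { b.offset_right = n_left + n_right; n_right += b.right.size(); }
  for (const Block& b : *blocks) {
    std::copy(b.left.begin(),  b.left.end(),  rows + b.offset_left);
    std::copy(b.right.begin(), b.right.end(), rows + b.offset_right);
  }
  return {n_left, n_right};
}
```

`AddSplitsToRowSet` then passes the resulting `(n_left, n_right)` to `RowSetCollection::AddSplit`, which now only validates the sizes and re-points the child `Elem` spans instead of copying thread-local vectors.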
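
Both `ApplySplit` and the reworked `UpdatePredictionCache` above rely on a blocked 2D task space: dimension one is the node, dimension two is that node's rows chopped into fixed-size blocks, and every (node, block) pair becomes one independent task. The following sketch illustrates the idea with plain OpenMP and hypothetical names; `BlockedSpace2d`/`ParallelFor2d` in the patch carry the same information through their own scheduler.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

void ParallelFor2dSketch(
    std::size_t n_nodes,
    const std::function<std::size_t(std::size_t)>& node_size,  // rows per node
    std::size_t block_size, int n_threads,
    const std::function<void(std::size_t, std::size_t, std::size_t)>& fn) {
  // flatten (node, row-block) pairs into a single task list
  std::vector<std::pair<std::size_t, std::pair<std::size_t, std::size_t>>> tasks;
  for (std::size_t nid = 0; nid < n_nodes; ++nid) {
    const std::size_t size = node_size(nid);
    for (std::size_t begin = 0; begin < size; begin += block_size) {
      tasks.emplace_back(nid, std::make_pair(begin, std::min(begin + block_size, size)));
    }
  }
#pragma omp parallel for schedule(guided) num_threads(n_threads)
  for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(tasks.size()); ++i) {
    // fn(node_in_set, range_begin, range_end)
    fn(tasks[i].first, tasks[i].second.first, tasks[i].second.second);
  }
}
```

Because nodes of very different sizes are split into equal blocks, threads stay busy even when one child holds most of the rows, which is why the prediction-cache update can iterate node by node without serializing on the largest leaf.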