Skip to content

Commit

Permalink
Span class. (#3548)
Browse files Browse the repository at this point in the history
* Add basic Span class based on ISO++20.

* Use Span<Entry const> instead of Inst in SparsePage.

* Add DeviceSpan in HostDeviceVector, use it in regression obj.
  • Loading branch information
trivialfis authored and RAMitchell committed Aug 14, 2018
1 parent 2b7a1c5 commit 2c50278
Show file tree
Hide file tree
Showing 28 changed files with 1,927 additions and 138 deletions.
33 changes: 12 additions & 21 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "./base.h"
#include "../../src/common/span.h"

namespace xgboost {
// forward declare learner.
Expand Down Expand Up @@ -133,7 +134,7 @@ struct Entry {
/*!
* \brief constructor with index and value
* \param index The feature or row index.
* \param fvalue THe feature value.
* \param fvalue The feature value.
*/
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
/*! \brief reversely compare feature values */
Expand All @@ -155,24 +156,14 @@ class SparsePage {
std::vector<Entry> data;

size_t base_rowid;

/*! \brief an instance of sparse vector in the batch */
struct Inst {
/*! \brief pointer to the elements*/
const Entry *data{nullptr};
/*! \brief length of the instance */
bst_uint length{0};
/*! \brief constructor */
Inst() = default;
Inst(const Entry *data, bst_uint length) : data(data), length(length) {}
/*! \brief get i-th pair in the sparse vector*/
inline const Entry& operator[](size_t i) const {
return data[i];
}
};
using Inst = common::Span<Entry const>;

/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return {data.data() + offset[i], static_cast<bst_uint>(offset[i + 1] - offset[i])};
return {data.data() + offset[i],
static_cast<Inst::index_type>(offset[i + 1] - offset[i])};
}

/*! \brief constructor */
Expand Down Expand Up @@ -234,12 +225,12 @@ class SparsePage {
* \param inst an instance row
*/
inline void Push(const Inst &inst) {
offset.push_back(offset.back() + inst.length);
offset.push_back(offset.back() + inst.size());
size_t begin = data.size();
data.resize(begin + inst.length);
if (inst.length != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
sizeof(Entry) * inst.length);
data.resize(begin + inst.size());
if (inst.size() != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data(),
sizeof(Entry) * inst.size());
}
}

Expand Down Expand Up @@ -328,7 +319,7 @@ class DMatrix {
* \brief check if column access is supported, if not, initialize column access.
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
* this is a hint information that can be ignored by the implementation.
* \param sorted If column features should be in sorted order
* \param sorted If column features should be in sorted order
* \return Number of column blocks in the column access.
*/
virtual void InitColAccess(size_t max_row_perbatch, bool sorted) = 0;
Expand Down
4 changes: 2 additions & 2 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,14 @@ inline void RegTree::FVec::Init(size_t size) {
}

inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].fvalue = inst[i].fvalue;
}
}

inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].flag = -1;
}
Expand Down
8 changes: 4 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,10 +687,10 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int ridx = idxset[i];
auto inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
ret.page_.data.insert(ret.page_.data.end(), inst.data,
inst.data + inst.length);
ret.page_.offset.push_back(ret.page_.offset.back() + inst.length);
ret.info.num_nonzero_ += inst.length;
ret.page_.data.insert(ret.page_.data.end(), inst.data(),
inst.data() + inst.size());
ret.page_.offset.push_back(ret.page_.offset.back() + inst.size());
ret.info.num_nonzero_ += inst.size();

if (src.info.labels_.size() != 0) {
ret.info.labels_.push_back(src.info.labels_[ridx]);
Expand Down
14 changes: 8 additions & 6 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
size_t ridx = batch.base_rowid + i;
SparsePage::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (inst[j].index >= begin && inst[j].index < end) {
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
for (auto& ins : inst) {
if (ins.index >= begin && ins.index < end) {
sketchs[ins.index].Push(ins.fvalue, info.GetWeight(ridx));
}
}
}
Expand Down Expand Up @@ -140,7 +140,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
auto &batch = iter->Value();
const size_t rbegin = row_ptr.size() - 1;
for (size_t i = 0; i < batch.Size(); ++i) {
row_ptr.push_back(batch[i].length + row_ptr.back());
row_ptr.push_back(batch[i].size() + row_ptr.back());
}
index.resize(row_ptr.back());

Expand All @@ -154,9 +154,11 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
SparsePage::Inst inst = batch[i];
CHECK_EQ(ibegin + inst.length, iend);
for (bst_uint j = 0; j < inst.length; ++j) {

CHECK_EQ(ibegin + inst.size(), iend);
for (bst_uint j = 0; j < inst.size(); ++j) {
uint32_t idx = cut.GetBinIdx(inst[j]);

index[ibegin + j] = idx;
++hit_count_tloc_[tid * nbins + idx];
}
Expand Down
5 changes: 5 additions & 0 deletions src/common/host_device_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ GPUSet HostDeviceVector<T>::Devices() const { return GPUSet::Empty(); }
template <typename T>
T* HostDeviceVector<T>::DevicePointer(int device) { return nullptr; }

template <typename T>
common::Span<T> HostDeviceVector<T>::DeviceSpan(int device) {
return common::Span<T>();
}

template <typename T>
std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->data_h_; }

Expand Down
12 changes: 12 additions & 0 deletions src/common/host_device_vector.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ struct HostDeviceVectorImpl {
return shards_[devices_.Index(device)].data_.data().get();
}

common::Span<T> DeviceSpan(int device) {
CHECK(devices_.Contains(device));
LazySyncDevice(device);
return { shards_[devices_.Index(device)].data_.data().get(),
static_cast<typename common::Span<T>::index_type>(Size()) };
}

size_t DeviceSize(int device) {
CHECK(devices_.Contains(device));
LazySyncDevice(device);
Expand Down Expand Up @@ -323,6 +330,11 @@ GPUSet HostDeviceVector<T>::Devices() const { return impl_->Devices(); }
template <typename T>
T* HostDeviceVector<T>::DevicePointer(int device) { return impl_->DevicePointer(device); }

template <typename T>
common::Span<T> HostDeviceVector<T>::DeviceSpan(int device) {
return impl_->DeviceSpan(device);
}

template <typename T>
size_t HostDeviceVector<T>::DeviceStart(int device) { return impl_->DeviceStart(device); }

Expand Down
3 changes: 3 additions & 0 deletions src/common/host_device_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <initializer_list>
#include <vector>

#include "span.h"

// only include thrust-related files if host_device_vector.h
// is included from a .cu file
#ifdef __CUDACC__
Expand Down Expand Up @@ -117,6 +119,7 @@ class HostDeviceVector {
size_t Size() const;
GPUSet Devices() const;
T* DevicePointer(int device);
common::Span<T> DeviceSpan(int device);

T* HostPointer() { return HostVector().data(); }
size_t DeviceStart(int device);
Expand Down
Loading

0 comments on commit 2c50278

Please sign in to comment.