Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add number of columns to native data iterator. #5202

Merged
merged 1 commit into from
Feb 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/xgboost/c_api.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015 by Contributors
* Copyright (c) 2015~2020 by Contributors
* \file c_api.h
* \author Tianqi Chen
* \brief C API of XGBoost, used for interfacing to other languages.
Expand Down Expand Up @@ -40,6 +40,8 @@ typedef void *DataHolderHandle; // NOLINT(*)
typedef struct { // NOLINT(*)
/*! \brief number of rows in the minibatch */
size_t size;
/*! \brief number of columns in the minibatch. */
size_t columns;
/*! \brief row pointer to the rows in the data */
#ifdef __APPLE__
/* Necessary as Java on MacOS defines jlong as long int
Expand Down
5 changes: 4 additions & 1 deletion jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
limitations under the License.
*/

#include <cstddef>
#include <cstdint>
#include <limits>
#include <rabit/c_api.h>
#include <xgboost/c_api.h>
#include <xgboost/base.h>
Expand Down Expand Up @@ -88,9 +90,10 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
jintArray jindex = (jintArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
XGBoostBatchCSR cbatch;
cbatch.size = jenv->GetArrayLength(joffset) - 1;
cbatch.columns = std::numeric_limits<size_t>::max();
cbatch.offset = reinterpret_cast<jlong *>(
jenv->GetLongArrayElements(joffset, 0));
if (jlabel != nullptr) {
Expand Down
132 changes: 12 additions & 120 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <string>
#include <memory>


#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/learner.h"
#include "xgboost/c_api.h"
Expand All @@ -24,116 +24,6 @@
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"

namespace xgboost {
// Forward declaration of the set-data callback. NativeDataIter::Next() hands
// this function (together with the iterator itself as `handle`) to the user
// supplied next-batch callback, so it has to be visible before the class.
XGB_EXTERN_C int XGBoostNativeDataIterSetData(
void *handle, XGBoostBatchCSR batch);

/*!
 * \brief dmlc parser backed by a user supplied "next batch" callback.
 *
 * Next() invokes the callback with XGBoostNativeDataIterSetData and this
 * object as the holder handle; SetData() copies the received CSR batch into
 * buffers owned by this object and exposes them through Value().
 */
class NativeDataIter : public dmlc::Parser<uint32_t> {
 public:
  /*!
   * \param data_handle opaque handle forwarded to every callback invocation
   * \param next_callback callback asked to produce the next batch
   */
  NativeDataIter(DataIterHandle data_handle,
                 XGBCallbackDataIterNext* next_callback)
      : data_handle_(data_handle), next_callback_(next_callback) {}

  /*! \brief Only allowed while no call to Next() has succeeded yet. */
  void BeforeFirst() override {
    CHECK(at_first_) << "cannot reset NativeDataIter";
  }

  /*! \brief Run the callback once; returns false when the callback returns zero. */
  bool Next() override {
    if ((*next_callback_)(data_handle_, XGBoostNativeDataIterSetData, this) == 0) {
      return false;
    }
    at_first_ = false;
    return true;
  }

  /*! \brief Row block describing the buffers filled by the latest SetData(). */
  const dmlc::RowBlock<uint32_t>& Value() const override {
    return block_;
  }

  /*! \brief Accumulated size in bytes of every batch copied so far. */
  size_t BytesRead() const override {
    return bytes_read_;
  }

  /*!
   * \brief Copy a CSR batch into the internal buffers and refresh the row block.
   * \param batch batch to copy; null label/weight/index/value pointers leave
   *              the corresponding buffer empty
   */
  void SetData(const XGBoostBatchCSR& batch) {
    const size_t n_rows = batch.size;

    // Row pointers first: the entry range [first, last) selects what to copy below.
    offset_.assign(batch.offset, batch.offset + n_rows + 1);
    const size_t first = offset_.front();
    const size_t last = offset_.back();

    label_.clear();
    if (batch.label != nullptr) {
      label_.assign(batch.label, batch.label + n_rows);
    }

    weight_.clear();
    if (batch.weight != nullptr) {
      weight_.assign(batch.weight, batch.weight + n_rows);
    }

    index_.clear();
    if (batch.index != nullptr) {
      index_.assign(batch.index + first, batch.index + last);
    }

    value_.clear();
    if (batch.value != nullptr) {
      value_.assign(batch.value + first, batch.value + last);
    }

    // Only the [first, last) slice was copied, so shift the row pointers to start at zero.
    if (first != 0) {
      for (size_t& entry : offset_) {
        entry -= first;
      }
    }

    block_.size = n_rows;
    block_.offset = dmlc::BeginPtr(offset_);
    block_.label = dmlc::BeginPtr(label_);
    block_.weight = dmlc::BeginPtr(weight_);
    block_.qid = nullptr;
    block_.field = nullptr;
    block_.index = dmlc::BeginPtr(index_);
    block_.value = dmlc::BeginPtr(value_);

    const size_t batch_bytes = offset_.size() * sizeof(size_t) +
                               label_.size() * sizeof(dmlc::real_t) +
                               weight_.size() * sizeof(dmlc::real_t) +
                               index_.size() * sizeof(uint32_t) +
                               value_.size() * sizeof(dmlc::real_t);
    bytes_read_ += batch_bytes;
  }

 private:
  /*! \brief true until the first successful Next(), i.e. still at the beginning */
  bool at_first_{true};
  /*! \brief running total of bytes copied by SetData() */
  size_t bytes_read_{0};
  /*! \brief handle passed back to the callback on every invocation */
  DataIterHandle data_handle_;
  /*! \brief callback used to fetch the next batch */
  XGBCallbackDataIterNext* next_callback_;
  /*! \brief row pointers of the current batch, rebased to start at zero */
  std::vector<size_t> offset_;
  /*! \brief labels of the current batch */
  std::vector<dmlc::real_t> label_;
  /*! \brief weights of the current batch */
  std::vector<dmlc::real_t> weight_;
  /*! \brief feature indices of the current batch */
  std::vector<uint32_t> index_;
  /*! \brief feature values of the current batch */
  std::vector<dmlc::real_t> value_;
  /*! \brief row block whose pointers refer to the buffers above */
  dmlc::RowBlock<uint32_t> block_;
};

/*!
 * \brief Set-data callback: forwards a CSR batch to the NativeDataIter behind `handle`.
 * \param handle pointer to the receiving xgboost::NativeDataIter
 * \param batch batch passed on to NativeDataIter::SetData
 */
int XGBoostNativeDataIterSetData(
    void *handle, XGBoostBatchCSR batch) {
  API_BEGIN();
  auto *iter = static_cast<xgboost::NativeDataIter*>(handle);
  iter->SetData(batch);
  API_END();
}
} // namespace xgboost

using namespace xgboost; // NOLINT(*);

/*! \brief entry to easily hold returning information */
Expand Down Expand Up @@ -186,21 +76,23 @@ int XGDMatrixCreateFromFile(const char *fname,
API_END();
}

int XGDMatrixCreateFromDataIter(
void* data_handle,
XGBCallbackDataIterNext* callback,
const char *cache_info,
DMatrixHandle *out) {
XGB_DLL int XGDMatrixCreateFromDataIter(
void *data_handle, // a Java iterator
XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp
const char *cache_info, DMatrixHandle *out) {
API_BEGIN();

std::string scache;
if (cache_info != nullptr) {
scache = cache_info;
}
NativeDataIter parser(data_handle, callback);
data::FileAdapter adapter(&parser);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, scache));
xgboost::data::IteratorAdapter adapter(data_handle, callback);
*out = new std::shared_ptr<DMatrix> {
DMatrix::Create(
&adapter, std::numeric_limits<float>::quiet_NaN(),
1, scache
)
};
API_END();
}

Expand Down
20 changes: 5 additions & 15 deletions src/common/group_data.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2014 by Contributors
* Copyright 2014-2020 by Contributors
* \file group_data.h
* \brief this file defines utils to group data by integer keys
* Input: given input sequence (key,value), (k1,v1), (k2,v2)
Expand All @@ -14,6 +14,7 @@
#ifndef XGBOOST_COMMON_GROUP_DATA_H_
#define XGBOOST_COMMON_GROUP_DATA_H_

#include <cstddef>
#include <vector>
#include <algorithm>

Expand Down Expand Up @@ -44,15 +45,6 @@ class ParallelGroupBuilder {
size_t base_row_offset = 0)
: rptr_(*p_rptr),
data_(*p_data),
thread_rptr_(tmp_thread_rptr_),
base_row_offset_(base_row_offset) {}
ParallelGroupBuilder(std::vector<SizeType> *p_rptr,
std::vector<ValueType> *p_data,
std::vector<std::vector<SizeType> > *p_thread_rptr,
size_t base_row_offset = 0)
: rptr_(*p_rptr),
data_(*p_data),
thread_rptr_(*p_thread_rptr),
base_row_offset_(base_row_offset) {}

/*!
Expand All @@ -61,7 +53,7 @@ class ParallelGroupBuilder {
* \param max_key number of keys in the matrix, can be smaller than expected
* \param nthread number of thread that will be used in construction
*/
inline void InitBudget(std::size_t max_key, int nthread) {
void InitBudget(std::size_t max_key, int nthread) {
thread_rptr_.resize(nthread);
for (std::size_t i = 0; i < thread_rptr_.size(); ++i) {
thread_rptr_[i].resize(max_key - std::min(base_row_offset_, max_key));
Expand All @@ -74,7 +66,7 @@ class ParallelGroupBuilder {
* \param threadid the id of thread that calls this function
* \param nelem number of element budget add to this row
*/
inline void AddBudget(std::size_t key, int threadid, SizeType nelem = 1) {
void AddBudget(std::size_t key, int threadid, SizeType nelem = 1) {
std::vector<SizeType> &trptr = thread_rptr_[threadid];
size_t offset_key = key - base_row_offset_;
if (trptr.size() < offset_key + 1) {
Expand Down Expand Up @@ -129,9 +121,7 @@ class ParallelGroupBuilder {
/*! \brief index of nonzero entries in each row */
std::vector<ValueType> &data_;
/*! \brief thread local data structure */
std::vector<std::vector<SizeType> > &thread_rptr_;
/*! \brief local temp thread ptr, use this if not specified by the constructor */
std::vector<std::vector<SizeType> > tmp_thread_rptr_;
std::vector<std::vector<SizeType> > thread_rptr_;
/** \brief Used when rows being pushed into the builder are strictly above some number. */
size_t base_row_offset_;
};
Expand Down
Loading