Skip to content

Commit

Permalink
Add integer, serializable, CMake version.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 4, 2019
1 parent 4760394 commit 317e124
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 58 deletions.
6 changes: 6 additions & 0 deletions include/xgboost/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
#endif // GLIBC VERSION

#if defined(__GNUC__)
#define XGBOOST_EXPECT(cond, ret) __builtin_expect((cond), (ret))
#else
#define XGBOOST_EXPECT(cond, ret) (cond)
#endif // defined(__GNUC__)

/*!
* \brief Tag function as usable by device
*/
Expand Down
98 changes: 90 additions & 8 deletions include/xgboost/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
#ifndef XGBOOST_JSON_H_
#define XGBOOST_JSON_H_

#include <dmlc/io.h> // deprecated
#include <xgboost/logging.h>

#include <string>

#include <map>
Expand Down Expand Up @@ -185,7 +185,7 @@ class JsonNumber : public Value {

public:
JsonNumber() : Value(ValueKind::Number) {}
JsonNumber(double value) : Value(ValueKind::Number) { // NOLINT
JsonNumber(Float value) : Value(ValueKind::Number) { // NOLINT
number_ = value;
}

Expand All @@ -198,6 +198,7 @@ class JsonNumber : public Value {
Float const& getNumber() const & { return number_; }
Float& getNumber() & { return number_; }


bool operator==(Value const& rhs) const override;
Value& operator=(Value const& rhs) override;

Expand All @@ -206,6 +207,34 @@ class JsonNumber : public Value {
}
};

class JsonInteger : public Value {
public:
using Int = int64_t;

private:
Int integer_;

public:
JsonInteger() : Value(ValueKind::Integer), integer_{0} {}
template <typename IntT, typename std::enable_if<std::is_same<IntT, Int>::value>::type* = nullptr>
JsonInteger(IntT value) : Value(ValueKind::Integer), integer_{value} {}

Json& operator[](std::string const & key) override;
Json& operator[](int ind) override;

bool operator==(Value const& rhs) const override;
Value& operator=(Value const& rhs) override;

Int const& getInteger() && { return integer_; }
Int const& getInteger() const & { return integer_; }
Int& getInteger() & { return integer_; }
void Save(JsonWriter* writer) override;

static bool isClassOf(Value const* value) {
return value->Type() == ValueKind::Integer;
}
};

class JsonNull : public Value {
public:
JsonNull() : Value(ValueKind::Null) {}
Expand Down Expand Up @@ -256,15 +285,16 @@ class JsonBoolean : public Value {
};

struct StringView {
char const* str_;
using CharT = char; // unsigned char
CharT const* str_;
size_t size_;

public:
StringView() = default;
StringView(char const* str, size_t size) : str_{str}, size_{size} {}
StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}

char const& operator[](size_t p) const { return str_[p]; }
char const& at(size_t p) const { // NOLINT
CharT const& operator[](size_t p) const { return str_[p]; }
CharT const& at(size_t p) const { // NOLINT
CHECK_LT(p, size_);
return str_[p];
}
Expand Down Expand Up @@ -319,6 +349,13 @@ class Json {
return *this;
}

// integer
explicit Json(JsonInteger integer) : ptr_{new JsonInteger(integer)} {}
Json& operator=(JsonInteger integer) {
ptr_.reset(new JsonInteger(std::move(integer)));
return *this;
}

// array
explicit Json(JsonArray list) :
ptr_ {new JsonArray(std::move(list))} {}
Expand Down Expand Up @@ -410,10 +447,24 @@ JsonNumber::Float& GetImpl(T& val) { // NOLINT
template <typename T,
typename std::enable_if<
std::is_same<T, JsonNumber const>::value>::type* = nullptr>
double const& GetImpl(T& val) { // NOLINT
JsonNumber::Float const& GetImpl(T& val) { // NOLINT
return val.getNumber();
}

// Integer
template <typename T,
typename std::enable_if<
std::is_same<T, JsonInteger>::value>::type* = nullptr>
JsonInteger::Int& GetImpl(T& val) { // NOLINT
return val.getInteger();
}
template <typename T,
typename std::enable_if<
std::is_same<T, JsonInteger const>::value>::type* = nullptr>
JsonInteger::Int const& GetImpl(T& val) { // NOLINT
return val.getInteger();
}

// String
template <typename T,
typename std::enable_if<
Expand Down Expand Up @@ -502,6 +553,7 @@ auto get(U& json) -> decltype(detail::GetImpl(*Cast<T>(&json.GetValue())))& { //
using Object = JsonObject;
using Array = JsonArray;
using Number = JsonNumber;
using Integer = JsonInteger;
using Boolean = JsonBoolean;
using String = JsonString;
using Null = JsonNull;
Expand All @@ -525,6 +577,36 @@ inline std::map<std::string, std::string> fromJson(std::map<std::string, Json> c
}
return res;
}

} // namespace xgboost

#include <rabit/rabit.h>

namespace xgboost {

struct Serializable : public rabit::Serializable {
virtual ~Serializable() = default;
/*!
* \brief load the model from a stream
* \param fi stream where to load the model from
*/
virtual void Load(dmlc::Stream *fi) override = 0;
/*!
* \brief saves the model to a stream
* \param fo stream where to save the model to
*/
virtual void Save(dmlc::Stream *fo) const override = 0;

/*!
* \brief load the model from a json object
* \param in json object where to load the model from
*/
virtual void Load(Json const& in) = 0;
/*!
* \breif saves the model to a json object
* \param out json container where to save the model to
*/
virtual void Save(Json* out) const = 0;
};
} // namespace xgboost

#endif // XGBOOST_JSON_H_
7 changes: 6 additions & 1 deletion include/xgboost/json_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ class JsonReader {
explicit SourceLocation(size_t pos) : pos_{pos} {}
size_t Pos() const { return pos_; }

SourceLocation& Forward(char c = 0) {
SourceLocation& Forward() {
pos_++;
return *this;
}
SourceLocation& Forward(uint32_t n) {
pos_ += n;
return *this;
}
} cursor_;

StringView raw_str_;
Expand Down Expand Up @@ -207,6 +211,7 @@ class JsonWriter {
virtual void Visit(JsonArray const* arr);
virtual void Visit(JsonObject const* obj);
virtual void Visit(JsonNumber const* num);
virtual void Visit(JsonInteger const* num);
virtual void Visit(JsonRaw const* raw);
virtual void Visit(JsonNull const* null);
virtual void Visit(JsonString const* str);
Expand Down
8 changes: 3 additions & 5 deletions include/xgboost/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_

#include <rabit/rabit.h>

#include <xgboost/base.h>
#include <xgboost/gbm.h>
#include <xgboost/metric.h>
Expand Down Expand Up @@ -42,7 +40,7 @@ namespace xgboost {
*
* \endcode
*/
class Learner : public rabit::Serializable {
class Learner : public Serializable {
public:
/*! \brief virtual destructor */
~Learner() override = default;
Expand All @@ -51,14 +49,14 @@ class Learner : public rabit::Serializable {
*/
virtual void Configure() = 0;

virtual void Load(Json const& in) = 0;
virtual void Load(Json const& in) override = 0;
/*!
* \brief load model from stream
* \param fi input stream.
*/
void Load(dmlc::Stream* fi) override = 0;

virtual void Save(Json* out) const = 0;
virtual void Save(Json* out) const override = 0;
/*!
* \brief save model to stream.
* \param fo output stream
Expand Down
Loading

0 comments on commit 317e124

Please sign in to comment.