Skip to content

Commit

Permalink
Basic R implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 15, 2019
1 parent d86c418 commit 528c1be
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 17 deletions.
3 changes: 3 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ S3method(slice,xgb.DMatrix)
export("xgb.attr<-")
export("xgb.attributes<-")
export("xgb.parameters<-")
export("xgb.config<-")
export(cb.cv.predict)
export(cb.early.stop)
export(cb.evaluation.log)
Expand All @@ -29,6 +30,7 @@ export(xgb.Booster.complete)
export(xgb.DMatrix)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.config)
export(xgb.attributes)
export(xgb.create.features)
export(xgb.cv)
Expand All @@ -38,6 +40,7 @@ export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
export(xgb.importance)
export(xgb.load)
export(xgb.load.raw)
export(xgb.model.dt.tree)
export(xgb.plot.deepness)
export(xgb.plot.importance)
Expand Down
43 changes: 39 additions & 4 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,34 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(), modelfile =
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
stop("cachelist must be a list of xgb.DMatrix objects")
}

handle <- .Call(XGBoosterCreate_R, cachelist)
## Load existing model, dispatch for on disk model file and in memory buffer
if (!is.null(modelfile)) {
if (typeof(modelfile) == "character") {
## A filename
handle <- .Call(XGBoosterCreate_R, cachelist)
.Call(XGBoosterLoadModel_R, handle, modelfile[1])
class(handle) <- "xgb.Booster.handle"
if (length(params) > 0) {
xgb.parameters(handle) <- params
}
return(handle)
} else if (typeof(modelfile) == "raw") {
.Call(XGBoosterLoadModelFromRaw_R, handle, modelfile)
## A memory buffer
bst <- xgb.load.raw(modelfile)
xgb.parameters(bst) <- params
return (bst)
} else if (inherits(modelfile, "xgb.Booster")) {
## A booster object
bst <- xgb.Booster.complete(modelfile, saveraw = TRUE)
.Call(XGBoosterLoadModelFromRaw_R, handle, bst$raw)
bst <- xgb.load.raw(bst$raw)
xgb.parameters(bst) <- params
return (bst)
} else {
stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object")
}
}
## Create new model
handle <- .Call(XGBoosterCreate_R, cachelist)
class(handle) <- "xgb.Booster.handle"
if (length(params) > 0) {
xgb.parameters(handle) <- params
Expand Down Expand Up @@ -500,6 +514,27 @@ xgb.attributes <- function(object) {
object
}

#' Read or replace a booster's internal configuration as a JSON string.
#'
#' The getter returns the configuration produced by the native
#' \code{XGBoosterSaveJsonConfig_R} routine for the handle behind
#' \code{object}.
#'
#' @param object Object of class \code{xgb.Booster}
#' @param value A JSON string.
#'
#' @rdname xgb.config
#' @export
xgb.config <- function(object) {
  # Resolve the native handle inline and ask the C side for the JSON config.
  .Call(XGBoosterSaveJsonConfig_R, xgb.get.handle(object))
}

#' @rdname xgb.config
#' @export
`xgb.config<-` <- function(object, value) {
  # Push the JSON configuration into the native booster behind `object`.
  handle <- xgb.get.handle(object)
  .Call(XGBoosterLoadJsonConfig_R, handle, value)
  # Only a full xgb.Booster (a list) carries a cached `raw` copy that must be
  # refreshed after the native state changed. A bare handle is an external
  # pointer, and `$<-` on it would raise an error, so it is returned as is.
  # NOTE(review): assumes xgb.get.handle() also accepts a bare
  # xgb.Booster.handle -- confirm against its definition.
  if (inherits(object, "xgb.Booster")) {
    object$raw <- xgb.save.raw(object)
  }
  object
}

#' Accessors for model parameters.
#'
#' Only the setter for xgboost parameters is currently implemented.
Expand Down
16 changes: 16 additions & 0 deletions R-package/R/xgb.load.raw.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#' Load serialised xgboost model from R's raw vector
#'
#' User can generate raw memory buffer by calling xgb.save.raw
#'
#' The buffer is expected to be an R serialization of a list holding the
#' elements \code{model} and \code{config}, as written by xgb.save.raw.
#'
#' @param buffer the buffer returned by xgb.save.raw
#'
#' @return An object of class \code{xgb.Booster.handle}.
#'
#' @export
xgb.load.raw <- function(buffer) {
  # Validate before touching native code so bad input fails with a clear
  # message and no booster handle is allocated needlessly.
  if (!is.raw(buffer)) {
    stop("buffer must be a raw vector returned by xgb.save.raw")
  }
  # NOTE(review): unserialize() rebuilds arbitrary R objects; only feed it
  # buffers that come from a trusted source.
  modellst <- unserialize(buffer)
  # `[[` avoids the partial name matching that `$` performs on lists.
  if (!is.list(modellst) ||
      is.null(modellst[["model"]]) ||
      is.null(modellst[["config"]])) {
    stop("buffer does not contain a model saved by xgb.save.raw")
  }
  handle <- .Call(XGBoosterCreate_R, list())
  .Call(XGBoosterLoadModelFromRaw_R, handle, modellst[["model"]])
  .Call(XGBoosterLoadJsonConfig_R, handle, modellst[["config"]])
  class(handle) <- "xgb.Booster.handle"
  handle
}
18 changes: 10 additions & 8 deletions R-package/R/xgb.save.raw.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
#' Save xgboost model to R's raw vector,
#' user can call xgb.load.raw to load the model back from raw vector
#'
#' Save xgboost model from xgboost or xgb.train
#'
#' @param model the model object.
#'
#' @return A raw vector: the R serialization of a list with the elements
#'   \code{model} (the raw model buffer) and \code{config} (the JSON
#'   configuration string).
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
#' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#'                eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#' raw <- xgb.save.raw(bst)
#' bst <- xgb.load.raw(raw)
#' pred <- predict(bst, test$data)
#'
#' @export
xgb.save.raw <- function(model) {
  # Resolve the native handle once; the argument itself is left untouched so
  # the name `model` is not reused for two different things.
  handle <- xgb.get.handle(model)
  model_raw <- .Call(XGBoosterModelToRaw_R, handle)
  config <- .Call(XGBoosterSaveJsonConfig_R, handle)
  # Bundle model bytes and configuration so xgb.load.raw can restore both.
  serialize(list(model = model_raw, config = config), NULL)
}
4 changes: 4 additions & 0 deletions R-package/src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ extern SEXP XGBoosterGetAttrNames_R(SEXP);
extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterModelToRaw_R(SEXP);
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
Expand All @@ -49,6 +51,8 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2},
{"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 4},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
Expand Down
18 changes: 18 additions & 0 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,24 @@ SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
return R_NilValue;
}

// Fetch the booster's parameters as a JSON string via
// XGBoosterSaveJsonParameters and hand them back to R as a length-one
// character vector (mkString).
SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
  // Declared outside the API guard because it is read after R_API_END().
  const char* json_str;
  R_API_BEGIN();
  // Length reported by the C API; not needed by mkString, which reads up to
  // the terminating NUL.
  bst_ulong json_len {0};
  CHECK_CALL(XGBoosterSaveJsonParameters(R_ExternalPtrAddr(handle), &json_len, &json_str));
  R_API_END();
  return mkString(json_str);
}

// Load a JSON configuration string into the booster behind `handle`.
// `value` is coerced to a single string with asChar(). Always returns
// R_NilValue.
SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
  R_API_BEGIN();
  // Wrap the call in CHECK_CALL, as XGBoosterSaveJsonConfig_R does, so a
  // failed load is surfaced instead of being silently dropped.
  CHECK_CALL(XGBoosterLoadJsonParameters(R_ExternalPtrAddr(handle),
                                         CHAR(asChar(value))));
  R_API_END();
  return R_NilValue;
}

SEXP XGBoosterModelToRaw_R(SEXP handle) {
SEXP ret;
R_API_BEGIN();
Expand Down
16 changes: 15 additions & 1 deletion R-package/src/xgboost_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,23 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
* \brief save model into R's raw array
* \param handle handle
* \return raw array
*/
*/
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);

/*!
* \brief Save the booster's internal parameters as a JSON string
* \param handle handle of the booster (an R external pointer)
* \return the JSON string as an R character vector of length one
*/
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
/*!
* \brief Load the JSON string returned by XGBoosterSaveJsonConfig_R
* \param handle handle of the booster (an R external pointer)
* \param value JSON string; coerced to a single string before loading
* \return R_NilValue
*/
XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);

/*!
* \brief dump model into a string
* \param handle handle
Expand Down
4 changes: 3 additions & 1 deletion R-package/tests/testthat/test_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,17 @@ test_that("cb.save.model works as expected", {
expect_equal(xgb.ntree(b1), 1)
b2 <- xgb.load('xgboost_02.model')
expect_equal(xgb.ntree(b2), 2)
xgb.config(b2) <- xgb.config(bst)
expect_equal(bst$raw, b2$raw)

# save_period = 0 saves the last iteration's model
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
save_period = 0)
expect_true(file.exists('xgboost.model'))
b2 <- xgb.load('xgboost.model')
xgb.config(b2) <- xgb.config(bst)
expect_equal(bst$raw, b2$raw)

for (f in files) if (file.exists(f)) file.remove(f)
})

Expand Down
2 changes: 0 additions & 2 deletions R-package/tests/testthat/test_custom_objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ num_round <- 2
test_that("custom objective works", {
bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1100)
expect_false(is.null(bst$evaluation_log))
expect_false(is.null(bst$evaluation_log$eval_error))
expect_lt(bst$evaluation_log[num_round, eval_error], 0.03)
Expand All @@ -58,5 +57,4 @@ test_that("custom objective using DMatrix attr works", {
param$objective = logregobjattr
bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1100)
})
3 changes: 2 additions & 1 deletion include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ class RegTree : public Model {
int pleft = this->AllocNode();
int pright = this->AllocNode();
auto &node = nodes_[nid];
CHECK(node.IsLeaf());
CHECK(node.IsLeaf()) << "Node ID: " << nid << ", "
<< "Left child mark: " << node.LeftChild();
node.SetLeftChild(pleft);
node.SetRightChild(pright);
nodes_[node.LeftChild()].SetParent(nid, true);
Expand Down

0 comments on commit 528c1be

Please sign in to comment.