diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
index 11f5299c0b67..7ed12c704a9f 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
@@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable {
// handle to the booster.
private long handle = 0;
private int version = 0;
/**
 * Type of prediction requested from {@code inplace_predict}.
 *
 * <p>The integer code is passed straight through JNI to the native
 * prediction configuration ("type" field).
 */
public enum PredictionType {
  /** Transformed output value (e.g. probability). Native code 0. */
  kValue(0),
  /** Untransformed margin value. Native code 1. */
  kMargin(1);

  // Native integer code; primitive and final to avoid boxing and mutation.
  private final int ptype;

  // Enum constructors are implicitly private; no modifier needed.
  PredictionType(final int ptype) {
    this.ptype = ptype;
  }

  /**
   * @return the native integer code for this prediction type
   */
  public Integer getPType() {
    return ptype;
  }
}
/**
* Create a new Booster with empty stage.
@@ -375,6 +390,97 @@ private synchronized float[][] predict(DMatrix data,
return predicts;
}
+ /**
+ * Perform thread-safe prediction.
+ *
+ * @param data Flattened input matrix of features for prediction
+ * @param nrow The number of preditions to make (count of input matrix rows)
+ * @param ncol The number of features in the model (count of input matrix columns)
+ * @param missing Value indicating missing element in the data
input matrix
+ *
+ * @return predict Result matrix
+ */
+ public float[][] inplace_predict(float[] data,
+ int nrow,
+ int ncol,
+ float missing) throws XGBoostError {
+ int[] iteration_range = new int[2];
+ iteration_range[0] = 0;
+ iteration_range[1] = 0;
+ return this.inplace_predict(data, nrow, ncol,
+ missing, iteration_range, PredictionType.kValue, null);
+ }
+
+ /**
+ * Perform thread-safe prediction.
+ *
+ * @param data Flattened input matrix of features for prediction
+ * @param nrow The number of preditions to make (count of input matrix rows)
+ * @param ncol The number of features in the model (count of input matrix columns)
+ * @param missing Value indicating missing element in the data
input matrix
+ * @param iteration_range Specifies which layer of trees are used in prediction. For
+ * example, if a random forest is trained with 100 rounds.
+ * Specifying `iteration_range=[10, 20)`, then only the forests
+ * built during [10, 20) (half open set) rounds are used in this
+ * prediction.
+ *
+ * @return predict Result matrix
+ */
+ public float[][] inplace_predict(float[] data,
+ int nrow,
+ int ncol,
+ float missing, int[] iteration_range) throws XGBoostError {
+ return this.inplace_predict(data, nrow, ncol,
+ missing, iteration_range, PredictionType.kValue, null);
+ }
+
+
+ /**
+ * Perform thread-safe prediction.
+ *
+ * @param data Flattened input matrix of features for prediction
+ * @param nrow The number of preditions to make (count of input matrix rows)
+ * @param ncol The number of features in the model (count of input matrix columns)
+ * @param missing Value indicating missing element in the data
input matrix
+ * @param iteration_range Specifies which layer of trees are used in prediction. For
+ * example, if a random forest is trained with 100 rounds.
+ * Specifying `iteration_range=[10, 20)`, then only the forests
+ * built during [10, 20) (half open set) rounds are used in this
+ * prediction.
+ * @param predict_type What kind of prediction to run.
+ * @return predict Result matrix
+ */
+ public float[][] inplace_predict(float[] data,
+ int nrow,
+ int ncol,
+ float missing,
+ int[] iteration_range,
+ PredictionType predict_type,
+ float[] base_margin) throws XGBoostError {
+ if (iteration_range.length != 2) {
+ throw new XGBoostError(new String("Iteration range is expected to be [begin, end)."));
+ }
+ int ptype = predict_type.getPType();
+
+ int begin = iteration_range[0];
+ int end = iteration_range[1];
+
+ float[][] rawPredicts = new float[1][];
+ XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol,
+ missing,
+ begin, end, ptype, base_margin, rawPredicts));
+
+ int col = rawPredicts[0].length / nrow;
+ float[][] predicts = new float[nrow][col];
+ int r, c;
+ for (int i = 0; i < rawPredicts[0].length; i++) {
+ r = i / col;
+ c = i % col;
+ predicts[r][c] = rawPredicts[0][i];
+ }
+ return predicts;
+ }
+
/**
* Predict leaf indices given the data
*
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
index d71d0a4f5c81..eabbf29ba945 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
@@ -119,6 +119,10 @@ public final static native int XGBoosterEvalOneIter(long handle, int iter, long[
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
int ntree_limit, float[][] predicts);
+ /**
+  * Run thread-safe in-place prediction on a dense row-major float matrix
+  * (backs {@code Booster.inplace_predict}).
+  *
+  * nrow/ncol are the input row and feature counts; iteration_begin/iteration_end
+  * give the half-open tree-layer range [begin, end); predict_type is the integer
+  * code from Booster.PredictionType; margin is an optional per-row base margin
+  * (may be null). The flat result array is written into predicts[0].
+  * Returns 0 on success, non-zero on native error.
+  */
+ public final static native int XGBoosterPredictFromDense(long handle, float[] data,
+ long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin,
+ float[][] predicts);
+
public final static native int XGBoosterLoadModel(long handle, String fname);
public final static native int XGBoosterSaveModel(long handle, String fname);
@@ -154,10 +158,6 @@ final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);
- @Deprecated
- public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
- java.util.Iterator iter, float missing, int nthread, int maxBin, long[] out);
-
public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator iter, java.util.Iterator ref, String config, long[] out);
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index c0c0774308fa..821b1ebff054 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
return ret;
}
+/*
+ * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method: XGBoosterPredictFromDense
+ * Signature: (J[FJJFIII[F[[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
+ JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
+ jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
+ jfloatArray jmargin, jobjectArray jout) {
+ API_BEGIN();
+ BoosterHandle handle = reinterpret_cast(jhandle);
+
+ /**
+ * Create array interface.
+ */
+ namespace linalg = xgboost::linalg;
+ jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
+ xgboost::Context ctx;
+ auto t_data = linalg::MakeTensorView(
+ ctx.Device(),
+ xgboost::common::Span{data, static_cast(num_rows * num_features)}, num_rows,
+ num_features);
+ auto s_array = linalg::ArrayInterfaceStr(t_data);
+
+ /**
+ * Create configuration object.
+ */
+ xgboost::Json config{xgboost::Object{}};
+ config["cache_id"] = xgboost::Integer{};
+ config["type"] = xgboost::Integer{static_cast(predict_type)};
+ config["iteration_begin"] = xgboost::Integer{static_cast(iteration_begin)};
+ config["iteration_end"] = xgboost::Integer{static_cast(iteration_end)};
+ config["missing"] = xgboost::Number{static_cast(missing)};
+ config["strict_shape"] = xgboost::Boolean{true};
+ std::string s_config;
+ xgboost::Json::Dump(config, &s_config);
+
+ /**
+ * Handle base margin
+ */
+ BoosterHandle proxy{nullptr};
+
+ float *margin{nullptr};
+ if (jmargin) {
+ margin = jenv->GetFloatArrayElements(jmargin, nullptr);
+ JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
+ JVM_CHECK_CALL(
+ XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
+ }
+
+ bst_ulong const *out_shape;
+ bst_ulong out_dim;
+ float const *result;
+ auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
+ &out_dim, &result);
+
+ jenv->ReleaseFloatArrayElements(jdata, data, 0);
+ if (proxy) {
+ XGDMatrixFree(proxy);
+ jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
+ }
+
+ if (ret != 0) {
+ return ret;
+ }
+
+ std::size_t n{1};
+ for (std::size_t i = 0; i < out_dim; ++i) {
+ n *= out_shape[i];
+ }
+
+ jfloatArray jarray = jenv->NewFloatArray(n);
+
+ jenv->SetFloatArrayRegion(jarray, 0, n, result);
+ jenv->SetObjectArrayElement(jout, 0, jarray);
+
+ API_END();
+}
+
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h
index b221c6a57da7..87ff6d30db6a 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.h
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h
@@ -207,6 +207,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
+/*
+ * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method: XGBoosterPredictFromDense
+ * Signature: (J[FJJFIII[F[[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense
+ (JNIEnv *, jclass, jlong, jfloatArray, jlong, jlong, jfloat, jint, jint, jint, jfloatArray, jobjectArray);
+
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
@@ -359,14 +367,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
(JNIEnv *, jclass, jlong, jstring, jstring);
-/*
- * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method: XGDeviceQuantileDMatrixCreateFromCallback
- * Signature: (Ljava/util/Iterator;FII[J)I
- */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
- (JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
-
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
index 70966a38f580..c7508b20d8ea 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
@@ -15,16 +15,24 @@
*/
package ml.dmlc.xgboost4j.java;
-import java.io.*;
-import java.util.*;
-
import junit.framework.TestCase;
+import org.junit.Assert;
import org.junit.Test;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+import java.util.concurrent.*;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.fail;
+
/**
- * test cases for Booster
- *
- * @author hzx
+ * test cases for Booster Inplace Predict
+ *
+ * @author hzx and Sovrn
*/
public class BoosterImplTest {
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
@@ -99,6 +107,179 @@ public void testBoosterBasic() throws XGBoostError, IOException {
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
}
+ @Test
+ public void inplacePredictTest() throws XGBoostError {
+ /* Data Generation */
+ // Generate a training set.
+ int trainRows = 1000;
+ int features = 10;
+ int trainSize = trainRows * features;
+ float[] trainX = generateRandomDataSet(trainSize);
+ float[] trainY = generateRandomDataSet(trainRows);
+
+ DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
+ trainingMatrix.setLabel(trainY);
+
+ // Generate a testing set
+ int testRows = 10;
+ int testSize = testRows * features;
+ float[] testX = generateRandomDataSet(testSize);
+ float[] testY = generateRandomDataSet(testRows);
+
+ DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
+ testingMatrix.setLabel(testY);
+
+ /* Training */
+
+ // Set parameters
+ Map params = new HashMap<>();
+ params.put("eta", 1.0);
+ params.put("max_depth",2);
+ params.put("silent", 1);
+ params.put("tree_method", "hist");
+
+ Map watches = new HashMap<>();
+ watches.put("train", trainingMatrix);
+ watches.put("test", testingMatrix);
+
+ Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
+
+ /* Prediction */
+
+ // Standard prediction
+ float[][] predictions = booster.predict(testingMatrix);
+
+ // Inplace-prediction
+ float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
+
+ // Confirm that the two prediction results are identical
+ assertArrayEquals(predictions, inplacePredictions);
+ }
+
+ @Test
+ public void inplacePredictMultiPredictTest() throws InterruptedException {
+ // Multithreaded, multiple prediction
+ int trainRows = 1000;
+ int features = 10;
+ int trainSize = trainRows * features;
+
+ int testRows = 10;
+ int testSize = testRows * features;
+
+ //Simulate multiple predictions on multiple random data sets simultaneously.
+ ExecutorService executorService = Executors.newFixedThreadPool(5);
+ int predictsToPerform = 100;
+ for(int i = 0; i < predictsToPerform; i++) {
+ executorService.submit(() -> {
+ try {
+ float[] trainX = generateRandomDataSet(trainSize);
+ float[] trainY = generateRandomDataSet(trainRows);
+ DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
+ trainingMatrix.setLabel(trainY);
+
+ float[] testX = generateRandomDataSet(testSize);
+ float[] testY = generateRandomDataSet(testRows);
+ DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
+ testingMatrix.setLabel(testY);
+
+ Map params = new HashMap<>();
+ params.put("eta", 1.0);
+ params.put("max_depth", 2);
+ params.put("silent", 1);
+ params.put("tree_method", "hist");
+
+ Map watches = new HashMap<>();
+ watches.put("train", trainingMatrix);
+ watches.put("test", testingMatrix);
+
+ Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
+
+ float[][] predictions = booster.predict(testingMatrix);
+ float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
+
+ assertArrayEquals(predictions, inplacePredictions);
+ } catch (XGBoostError xgBoostError) {
+ fail(xgBoostError.getMessage());
+ }
+ });
+ }
+ executorService.shutdown();
+ if(!executorService.awaitTermination(1, TimeUnit.MINUTES))
+ executorService.shutdownNow();
+ }
+
+ @Test
+ public void inplacePredictWithMarginTest() throws XGBoostError {
+ //Generate a training set
+ int trainRows = 1000;
+ int features = 10;
+ int trainSize = trainRows * features;
+ float[] trainX = generateRandomDataSet(trainSize);
+ float[] trainY = generateRandomDataSet(trainRows);
+
+ DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
+ trainingMatrix.setLabel(trainY);
+
+ // Generate a testing set
+ int testRows = 10;
+ int testSize = testRows * features;
+ float[] testX = generateRandomDataSet(testSize);
+ float[] testY = generateRandomDataSet(testRows);
+
+ DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
+ testingMatrix.setLabel(testY);
+
+ // Set booster parameters
+ Map params = new HashMap<>();
+ params.put("eta", 1.0);
+ params.put("max_depth",2);
+ params.put("tree_method", "hist");
+ params.put("base_score", 0.0);
+
+ Map watches = new HashMap<>();
+ watches.put("train", trainingMatrix);
+ watches.put("test", testingMatrix);
+
+ // Train booster on training matrix.
+ Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
+
+ // Create a margin
+ float[] margin = new float[testRows];
+ Arrays.fill(margin, 0.5f);
+
+ // Define an iteration range to use all training iterations, this should match
+ // the without margin call
+ // which defines an iteration range of [0,0)
+ int[] iterationRange = new int[] { 0, 0 };
+
+ float[][] inplacePredictionsWithMargin = booster.inplace_predict(testX,
+ testRows,
+ features,
+ Float.NaN,
+ iterationRange,
+ Booster.PredictionType.kValue,
+ margin);
+ float[][] inplacePredictionsWithoutMargin = booster.inplace_predict(testX, testRows, features, Float.NaN);
+
+ for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
+ for (int j = 0; j < inplacePredictionsWithoutMargin[i].length; j++) {
+ inplacePredictionsWithoutMargin[i][j] += margin[j];
+ }
+ }
+ for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
+ assertArrayEquals(inplacePredictionsWithMargin[i], inplacePredictionsWithoutMargin[i], 1e-6f);
+ }
+ }
+
+ private float[] generateRandomDataSet(int size) {
+ float[] newSet = new float[size];
+ Random random = new Random();
+ for(int i = 0; i < size; i++) {
+ newSet[i] = random.nextFloat();
+ }
+ return newSet;
+ }
+
@Test
public void saveLoadModelWithPath() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix(this.train_uri);