diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 11f5299c0b67..7ed12c704a9f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable { // handle to the booster. private long handle = 0; private int version = 0; + /** + * Type of prediction, used for inplace_predict. + */ + public enum PredictionType { + kValue(0), + kMargin(1); + + private Integer ptype; + private PredictionType(final Integer ptype) { + this.ptype = ptype; + } + public Integer getPType() { + return ptype; + } + } /** * Create a new Booster with empty stage. @@ -375,6 +390,97 @@ private synchronized float[][] predict(DMatrix data, return predicts; } + /** + * Perform thread-safe prediction. + * + * @param data Flattened input matrix of features for prediction + * @param nrow The number of predictions to make (count of input matrix rows) + * @param ncol The number of features in the model (count of input matrix columns) + * @param missing Value indicating missing element in the data input matrix + * + * @return predict Result matrix + */ + public float[][] inplace_predict(float[] data, + int nrow, + int ncol, + float missing) throws XGBoostError { + int[] iteration_range = new int[2]; + iteration_range[0] = 0; + iteration_range[1] = 0; + return this.inplace_predict(data, nrow, ncol, + missing, iteration_range, PredictionType.kValue, null); + } + + /** + * Perform thread-safe prediction. 
+ * + * @param data Flattened input matrix of features for prediction + * @param nrow The number of predictions to make (count of input matrix rows) + * @param ncol The number of features in the model (count of input matrix columns) + * @param missing Value indicating missing element in the data input matrix + * @param iteration_range Specifies which layer of trees are used in prediction. For + * example, if a random forest is trained with 100 rounds. + * Specifying `iteration_range=[10, 20)`, then only the forests + * built during [10, 20) (half open set) rounds are used in this + * prediction. + * + * @return predict Result matrix + */ + public float[][] inplace_predict(float[] data, + int nrow, + int ncol, + float missing, int[] iteration_range) throws XGBoostError { + return this.inplace_predict(data, nrow, ncol, + missing, iteration_range, PredictionType.kValue, null); + } + + + /** + * Perform thread-safe prediction. + * + * @param data Flattened input matrix of features for prediction + * @param nrow The number of predictions to make (count of input matrix rows) + * @param ncol The number of features in the model (count of input matrix columns) + * @param missing Value indicating missing element in the data input matrix + * @param iteration_range Specifies which layer of trees are used in prediction. For + * example, if a random forest is trained with 100 rounds. + * Specifying `iteration_range=[10, 20)`, then only the forests + * built during [10, 20) (half open set) rounds are used in this + * prediction. + * @param predict_type What kind of prediction to run. 
+ * @return predict Result matrix + */ + public float[][] inplace_predict(float[] data, + int nrow, + int ncol, + float missing, + int[] iteration_range, + PredictionType predict_type, + float[] base_margin) throws XGBoostError { + if (iteration_range == null || iteration_range.length != 2) { + throw new XGBoostError("Iteration range is expected to be [begin, end)."); + } + int ptype = predict_type.getPType(); + + int begin = iteration_range[0]; + int end = iteration_range[1]; + + float[][] rawPredicts = new float[1][]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol, + missing, + begin, end, ptype, base_margin, rawPredicts)); + + int col = rawPredicts[0].length / nrow; + float[][] predicts = new float[nrow][col]; + int r, c; + for (int i = 0; i < rawPredicts[0].length; i++) { + r = i / col; + c = i % col; + predicts[r][c] = rawPredicts[0][i]; + } + return predicts; + } + /** * Predict leaf indices given the data * diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index d71d0a4f5c81..eabbf29ba945 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -119,6 +119,10 @@ public final static native int XGBoosterEvalOneIter(long handle, int iter, long[ public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, int ntree_limit, float[][] predicts); + public final static native int XGBoosterPredictFromDense(long handle, float[] data, + long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin, + float[][] predicts); + public final static native int XGBoosterLoadModel(long handle, String fname); public final static native int XGBoosterSaveModel(long handle, String fname); @@ -154,10 +158,6 @@ final static native int 
CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count, public final static native int XGDMatrixSetInfoFromInterface( long handle, String field, String json); - @Deprecated - public final static native int XGDeviceQuantileDMatrixCreateFromCallback( - java.util.Iterator iter, float missing, int nthread, int maxBin, long[] out); - public final static native int XGQuantileDMatrixCreateFromCallback( java.util.Iterator iter, java.util.Iterator ref, String config, long[] out); diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index c0c0774308fa..821b1ebff054 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict return ret; } +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterPredictFromDense + * Signature: (J[FJJFIII[F[[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense( + JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features, + jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type, + jfloatArray jmargin, jobjectArray jout) { + API_BEGIN(); + BoosterHandle handle = reinterpret_cast(jhandle); + + /** + * Create array interface. + */ + namespace linalg = xgboost::linalg; + jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr); + xgboost::Context ctx; + auto t_data = linalg::MakeTensorView( + ctx.Device(), + xgboost::common::Span{data, static_cast(num_rows * num_features)}, num_rows, + num_features); + auto s_array = linalg::ArrayInterfaceStr(t_data); + + /** + * Create configuration object. 
+ */ + xgboost::Json config{xgboost::Object{}}; + config["cache_id"] = xgboost::Integer{}; + config["type"] = xgboost::Integer{static_cast(predict_type)}; + config["iteration_begin"] = xgboost::Integer{static_cast(iteration_begin)}; + config["iteration_end"] = xgboost::Integer{static_cast(iteration_end)}; + config["missing"] = xgboost::Number{static_cast(missing)}; + config["strict_shape"] = xgboost::Boolean{true}; + std::string s_config; + xgboost::Json::Dump(config, &s_config); + + /** + * Handle base margin + */ + BoosterHandle proxy{nullptr}; + + float *margin{nullptr}; + if (jmargin) { + margin = jenv->GetFloatArrayElements(jmargin, nullptr); + JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy)); + JVM_CHECK_CALL( + XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin))); + } + + bst_ulong const *out_shape; + bst_ulong out_dim; + float const *result; + auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape, + &out_dim, &result); + + jenv->ReleaseFloatArrayElements(jdata, data, 0); + if (proxy) { + XGDMatrixFree(proxy); + jenv->ReleaseFloatArrayElements(jmargin, margin, 0); + } + + if (ret != 0) { + return ret; + } + + std::size_t n{1}; + for (std::size_t i = 0; i < out_dim; ++i) { + n *= out_shape[i]; + } + + jfloatArray jarray = jenv->NewFloatArray(n); + + jenv->SetFloatArrayRegion(jarray, 0, n, result); + jenv->SetObjectArrayElement(jout, 0, jarray); + + API_END(); +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterLoadModel diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index b221c6a57da7..87ff6d30db6a 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -207,6 +207,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict (JNIEnv *, jclass, jlong, 
jlong, jint, jint, jobjectArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterPredictFromDense + * Signature: (J[FJJFIII[F[[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense + (JNIEnv *, jclass, jlong, jfloatArray, jlong, jlong, jfloat, jint, jint, jint, jfloatArray, jobjectArray); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterLoadModel @@ -359,14 +367,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface (JNIEnv *, jclass, jlong, jstring, jstring); -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGDeviceQuantileDMatrixCreateFromCallback - * Signature: (Ljava/util/Iterator;FII[J)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback - (JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray); - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGQuantileDMatrixCreateFromCallback diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 70966a38f580..c7508b20d8ea 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -15,16 +15,24 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.*; -import java.util.*; - import junit.framework.TestCase; +import org.junit.Assert; import org.junit.Test; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.util.*; +import java.util.concurrent.*; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.fail; + /** - * test cases for Booster - * - * @author hzx + * test cases 
for Booster Inplace Predict + * + * @author hzx and Sovrn */ public class BoosterImplTest { private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm"; @@ -99,6 +107,179 @@ public void testBoosterBasic() throws XGBoostError, IOException { TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f); } + @Test + public void inplacePredictTest() throws XGBoostError { + /* Data Generation */ + // Generate a training set. + int trainRows = 1000; + int features = 10; + int trainSize = trainRows * features; + float[] trainX = generateRandomDataSet(trainSize); + float[] trainY = generateRandomDataSet(trainRows); + + DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN); + trainingMatrix.setLabel(trainY); + + // Generate a testing set + int testRows = 10; + int testSize = testRows * features; + float[] testX = generateRandomDataSet(testSize); + float[] testY = generateRandomDataSet(testRows); + + DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN); + testingMatrix.setLabel(testY); + + /* Training */ + + // Set parameters + Map params = new HashMap<>(); + params.put("eta", 1.0); + params.put("max_depth",2); + params.put("silent", 1); + params.put("tree_method", "hist"); + + Map watches = new HashMap<>(); + watches.put("train", trainingMatrix); + watches.put("test", testingMatrix); + + Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null); + + /* Prediction */ + + // Standard prediction + float[][] predictions = booster.predict(testingMatrix); + + // Inplace-prediction + float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN); + + // Confirm that the two prediction results are identical + assertArrayEquals(predictions, inplacePredictions); + } + + @Test + public void inplacePredictMultiPredictTest() throws InterruptedException { + // Multithreaded, multiple prediction + int trainRows = 1000; + int features = 10; + int trainSize = 
trainRows * features; + + int testRows = 10; + int testSize = testRows * features; + + //Simulate multiple predictions on multiple random data sets simultaneously. + ExecutorService executorService = Executors.newFixedThreadPool(5); + int predictsToPerform = 100; + for(int i = 0; i < predictsToPerform; i++) { + executorService.submit(() -> { + try { + float[] trainX = generateRandomDataSet(trainSize); + float[] trainY = generateRandomDataSet(trainRows); + DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN); + trainingMatrix.setLabel(trainY); + + float[] testX = generateRandomDataSet(testSize); + float[] testY = generateRandomDataSet(testRows); + DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN); + testingMatrix.setLabel(testY); + + Map params = new HashMap<>(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("tree_method", "hist"); + + Map watches = new HashMap<>(); + watches.put("train", trainingMatrix); + watches.put("test", testingMatrix); + + Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null); + + float[][] predictions = booster.predict(testingMatrix); + float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN); + + assertArrayEquals(predictions, inplacePredictions); + } catch (XGBoostError xgBoostError) { + fail(xgBoostError.getMessage()); + } + }); + } + executorService.shutdown(); + if(!executorService.awaitTermination(1, TimeUnit.MINUTES)) + executorService.shutdownNow(); + } + + @Test + public void inplacePredictWithMarginTest() throws XGBoostError { + //Generate a training set + int trainRows = 1000; + int features = 10; + int trainSize = trainRows * features; + float[] trainX = generateRandomDataSet(trainSize); + float[] trainY = generateRandomDataSet(trainRows); + + DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN); + trainingMatrix.setLabel(trainY); + + // Generate a 
testing set + int testRows = 10; + int testSize = testRows * features; + float[] testX = generateRandomDataSet(testSize); + float[] testY = generateRandomDataSet(testRows); + + DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN); + testingMatrix.setLabel(testY); + + // Set booster parameters + Map params = new HashMap<>(); + params.put("eta", 1.0); + params.put("max_depth",2); + params.put("tree_method", "hist"); + params.put("base_score", 0.0); + + Map watches = new HashMap<>(); + watches.put("train", trainingMatrix); + watches.put("test", testingMatrix); + + // Train booster on training matrix. + Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null); + + // Create a margin + float[] margin = new float[testRows]; + Arrays.fill(margin, 0.5f); + + // Define an iteration range to use all training iterations, this should match + // the without margin call + // which defines an iteration range of [0,0) + int[] iterationRange = new int[] { 0, 0 }; + + float[][] inplacePredictionsWithMargin = booster.inplace_predict(testX, + testRows, + features, + Float.NaN, + iterationRange, + Booster.PredictionType.kValue, + margin); + float[][] inplacePredictionsWithoutMargin = booster.inplace_predict(testX, testRows, features, Float.NaN); + + for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) { + for (int j = 0; j < inplacePredictionsWithoutMargin[i].length; j++) { + inplacePredictionsWithoutMargin[i][j] += margin[i]; + } + } + for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) { + assertArrayEquals(inplacePredictionsWithMargin[i], inplacePredictionsWithoutMargin[i], 1e-6f); + } + } + + private float[] generateRandomDataSet(int size) { + float[] newSet = new float[size]; + Random random = new Random(); + for(int i = 0; i < size; i++) { + newSet[i] = random.nextFloat(); + } + return newSet; + } + @Test public void saveLoadModelWithPath() throws XGBoostError, IOException { DMatrix trainMat = new 
DMatrix(this.train_uri);