Allow JVM-Package to access inplace predict method #9167

Merged
24 commits, merged on Sep 11, 2023

Commits (24)
c5d448e
Logging in cpu predictor (+23 squashed commits)
StephanTLavavej May 12, 2023
ee0a029
Clean up some comments (+5 squashed commits)
yoquinjo Mar 23, 2023
5ce8d65
Additional documentation (+3 squashed commits)
yoquinjo May 17, 2023
058aec3
Formatting
yoquinjo May 18, 2023
79ce356
Adjust boosterimpltest import statements
yoquinjo May 18, 2023
86e38d4
Mimic assertion params from DMatrixTest
yoquinjo May 19, 2023
0c5b2df
Clean up some comments in BoosterImplTest
yoquinjo May 23, 2023
312c874
update attribution of authorship
yoquinjo May 23, 2023
bed3035
Update BoosterImplTest.java format
yoquinjo Jun 1, 2023
57d88da
Start working on inplace prediction.
trivialfis Aug 7, 2023
b923c95
Replace the implementation.
trivialfis Aug 7, 2023
3b88aaf
jni
trivialfis Aug 7, 2023
e132136
test.
trivialfis Aug 7, 2023
d623163
cleanup.
trivialfis Aug 7, 2023
e4f5ef8
win.
trivialfis Aug 7, 2023
bccd539
cleanup.
trivialfis Aug 9, 2023
8b4885d
Merge pull request #9 from trivialfis/jvm-inplace-predict
yoquinjo Aug 21, 2023
30ed3ca
Merge branch 'master' into sovrn-inplace-predict-java
yoquinjo Aug 22, 2023
509c880
EXDS-35- Cleaned up inplace_predict test. Refactored into two separat…
ByteSizedJoe Aug 30, 2023
272ed03
EXDS-35- Modified iteration range to use [0,10] since the number of r…
ByteSizedJoe Aug 30, 2023
cc8e79b
EXDS-35 - Added a few clarifying comments, explicitly defined iterati…
ByteSizedJoe Aug 31, 2023
81c23c2
EXDS-35 - Brought in trivial from dmlc's changes which include settin…
ByteSizedJoe Aug 31, 2023
dec1375
Merge remote-tracking branch 'dmlc/master' into sovrn-inplace-predict…
ByteSizedJoe Sep 1, 2023
e91026a
Change bst_d_ordinal_t kCpuId to use DeviceOrd
ByteSizedJoe Sep 8, 2023
Files changed

@@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable {
// handle to the booster.
private long handle = 0;
private int version = 0;
/**
* Type of prediction, used for inplace_predict.
*/
public enum PredictionType {
kValue(0),
kMargin(1);

private Integer ptype;
private PredictionType(final Integer ptype) {
this.ptype = ptype;
}
public Integer getPType() {
return ptype;
}
}

/**
* Create a new Booster with empty stage.
@@ -375,6 +390,97 @@ private synchronized float[][] predict(DMatrix data,
return predicts;
}

/**
* Perform thread-safe prediction.
*
* @param data Flattened input matrix of features for prediction
* @param nrow The number of predictions to make (count of input matrix rows)
* @param ncol The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
*
* @return predict Result matrix
*/
public float[][] inplace_predict(float[] data,
Member (inline review comment):
@wbo4958 I just learned about the existence of BigDenseMatrix, do you think we should return the prediction in that?

int nrow,
int ncol,
float missing) throws XGBoostError {
int[] iteration_range = new int[2];
iteration_range[0] = 0;
iteration_range[1] = 0;
return this.inplace_predict(data, nrow, ncol,
missing, iteration_range, PredictionType.kValue, null);
}

/**
* Perform thread-safe prediction.
*
* @param data Flattened input matrix of features for prediction
* @param nrow The number of predictions to make (count of input matrix rows)
* @param ncol The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
* @param iteration_range Specifies which layers of trees are used in prediction. For
* example, if a random forest is trained with 100 rounds,
* specifying `iteration_range=[10, 20)` means only the forests
* built during rounds [10, 20) (half-open interval) are used for
* this prediction.
*
* @return predict Result matrix
*/
public float[][] inplace_predict(float[] data,
int nrow,
int ncol,
float missing, int[] iteration_range) throws XGBoostError {
return this.inplace_predict(data, nrow, ncol,
missing, iteration_range, PredictionType.kValue, null);
}


/**
* Perform thread-safe prediction.
*
* @param data Flattened input matrix of features for prediction
* @param nrow The number of predictions to make (count of input matrix rows)
* @param ncol The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
* @param iteration_range Specifies which layers of trees are used in prediction. For
* example, if a random forest is trained with 100 rounds,
* specifying `iteration_range=[10, 20)` means only the forests
* built during rounds [10, 20) (half-open interval) are used for
* this prediction.
* @param predict_type What kind of prediction to run.
* @param base_margin Optional per-row base margin to predict from; pass null to use
* the model's default base score.
* @return predict Result matrix
*/
public float[][] inplace_predict(float[] data,
int nrow,
int ncol,
float missing,
int[] iteration_range,
PredictionType predict_type,
float[] base_margin) throws XGBoostError {
if (iteration_range.length != 2) {
throw new XGBoostError(new String("Iteration range is expected to be [begin, end)."));
}
int ptype = predict_type.getPType();

int begin = iteration_range[0];
int end = iteration_range[1];

float[][] rawPredicts = new float[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol,
missing,
begin, end, ptype, base_margin, rawPredicts));

int col = rawPredicts[0].length / nrow;
float[][] predicts = new float[nrow][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
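For context, a minimal sketch of how these overloads might be called from user code; the booster variable, matrix contents, and iteration range values are illustrative assumptions and not part of this PR. The data array is flattened row-major, so feature j of row i sits at data[i * ncol + j], and an iteration_range of [0, 0) is treated as selecting all tree layers (as the default overload above does).

// Features for 3 rows and 2 columns, flattened in row-major order.
float[] data = new float[] {1f, 2f, 3f, 4f, 5f, 6f};

// Default overload: all tree layers, transformed (kValue) output.
float[][] preds = booster.inplace_predict(data, 3, 2, Float.NaN);

// Restrict prediction to trees built during rounds [0, 10) and request
// raw margin output; no per-row base margin is supplied.
float[][] margins = booster.inplace_predict(data, 3, 2, Float.NaN,
    new int[]{0, 10}, Booster.PredictionType.kMargin, null);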

/**
* Predict leaf indices given the data
*
@@ -119,6 +119,10 @@ public final static native int XGBoosterEvalOneIter(long handle, int iter, long[
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
int ntree_limit, float[][] predicts);

public final static native int XGBoosterPredictFromDense(long handle, float[] data,
long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin,
float[][] predicts);

public final static native int XGBoosterLoadModel(long handle, String fname);

public final static native int XGBoosterSaveModel(long handle, String fname);
@@ -154,10 +158,6 @@ final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);

@Deprecated
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);

public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);

79 changes: 79 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterPredictFromDense
* Signature: (J[FJJFIII[F[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
jfloatArray jmargin, jobjectArray jout) {
API_BEGIN();
BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle);

/**
* Create array interface.
*/
namespace linalg = xgboost::linalg;
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
xgboost::Context ctx;
auto t_data = linalg::MakeTensorView(
ctx.Device(),
xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows,
num_features);
auto s_array = linalg::ArrayInterfaceStr(t_data);

/**
* Create configuration object.
*/
xgboost::Json config{xgboost::Object{}};
config["cache_id"] = xgboost::Integer{};
config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)};
config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)};
config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)};
config["missing"] = xgboost::Number{static_cast<float>(missing)};
config["strict_shape"] = xgboost::Boolean{true};
std::string s_config;
xgboost::Json::Dump(config, &s_config);

/**
* Handle base margin
*/
BoosterHandle proxy{nullptr};
Member (inline review comment):
Please help add a test for prediction with margin. With regression using reg:squarederror:

pred0 = booster.inplace_predict(X, margin)
pred1 = booster.inplace_predict(X) + margin
assert pred0 == pred1
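
A minimal sketch of that check as a standalone program, written against the overloads added in this diff; the synthetic data, training parameters, and tolerance are illustrative assumptions. Note that pred0 == pred1 + margin only holds when base_score is 0, since in XGBoost a supplied base margin replaces the model's base score per row, so the sketch sets base_score explicitly.

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class InplacePredictMarginSketch {
  public static void main(String[] args) throws XGBoostError {
    int nrow = 100;
    int ncol = 4;
    float[] X = new float[nrow * ncol];   // flattened row-major features
    float[] y = new float[nrow];
    float[] margin = new float[nrow];     // one base-margin value per row
    Random rng = new Random(0);
    for (int i = 0; i < X.length; i++) {
      X[i] = rng.nextFloat();
    }
    for (int i = 0; i < nrow; i++) {
      y[i] = rng.nextFloat();
      margin[i] = rng.nextFloat();
    }

    DMatrix dtrain = new DMatrix(X, nrow, ncol, Float.NaN);
    dtrain.setLabel(y);

    Map<String, Object> params = new HashMap<>();
    params.put("objective", "reg:squarederror");
    params.put("base_score", 0.0);        // required for the identity below to hold
    Booster booster = XGBoost.train(dtrain, params, 10,
        new HashMap<String, DMatrix>(), null, null);

    // Prediction that starts from the supplied per-row base margin.
    float[][] pred0 = booster.inplace_predict(X, nrow, ncol, Float.NaN,
        new int[]{0, 0}, Booster.PredictionType.kValue, margin);
    // Plain prediction; the margin is added afterwards instead.
    float[][] pred1 = booster.inplace_predict(X, nrow, ncol, Float.NaN);

    for (int i = 0; i < nrow; i++) {
      float diff = Math.abs(pred0[i][0] - (pred1[i][0] + margin[i]));
      if (diff > 1e-4f) {
        throw new AssertionError("Row " + i + " mismatch: " + diff);
      }
    }
    System.out.println("inplace_predict with base margin matches prediction + margin");
  }
}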

Contributor:
@trivialfis Hello! I wrote a test, but it appears to be failing. I think it may be related to this: #9536

Is the in-place prediction not working properly with margins at the moment?

Member:
The bug happens when you are running prediction with base margin on GPU while the input is from the CPU or the other way around. I don't think it's relevant to this PR.

Member:
Feel free to push your tests, I can take a look.

Contributor:
Thank you very much! I pushed the new tests. Let me know if you have any questions! Right now what I'm seeing with the failing test is that prediction(x with margin) != prediction(x) + margin.

Contributor:
Awesome! That did the trick. I'll post your changes from below shortly.

Member:
Great! Please merge/rebase the latest branch.

Contributor:
Done! Let me know if there's anything else needed!

Member:
> Done! Let me know if there's anything else needed!

Fix CI errors, if there are any. ;-)

Contributor:
It looks like the CI failed with some of the new C++ code you added from the sovrn#9 branch. I tried taking a look at it, but by no means am I a C++ expert. It looks like some sort of parameter type mismatch is occurring. Could you take a look?


float *margin{nullptr};
if (jmargin) {
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
JVM_CHECK_CALL(
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
}

bst_ulong const *out_shape;
bst_ulong out_dim;
float const *result;
auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
&out_dim, &result);

jenv->ReleaseFloatArrayElements(jdata, data, 0);
if (proxy) {
XGDMatrixFree(proxy);
jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
}

if (ret != 0) {
return ret;
}

std::size_t n{1};
for (std::size_t i = 0; i < out_dim; ++i) {
n *= out_shape[i];
}

jfloatArray jarray = jenv->NewFloatArray(n);

jenv->SetFloatArrayRegion(jarray, 0, n, result);
jenv->SetObjectArrayElement(jout, 0, jarray);

API_END();
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
16 changes: 8 additions & 8 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default.
