-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow JVM-Package to access inplace predict method #9167
Changes from all commits
c5d448e
ee0a029
5ce8d65
058aec3
79ce356
86e38d4
0c5b2df
312c874
bed3035
57d88da
b923c95
3b88aaf
e132136
d623163
e4f5ef8
bccd539
8b4885d
30ed3ca
509c880
272ed03
cc8e79b
81c23c2
dec1375
e91026a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict | |
return ret; | ||
} | ||
|
||
/* | ||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI | ||
* Method: XGBoosterPredictFromDense | ||
* Signature: (J[FJJFIII[F[[F)I | ||
*/ | ||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense( | ||
JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features, | ||
jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type, | ||
jfloatArray jmargin, jobjectArray jout) { | ||
API_BEGIN(); | ||
BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle); | ||
|
||
/** | ||
* Create array interface. | ||
*/ | ||
namespace linalg = xgboost::linalg; | ||
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr); | ||
xgboost::Context ctx; | ||
auto t_data = linalg::MakeTensorView( | ||
ctx.Device(), | ||
xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows, | ||
num_features); | ||
auto s_array = linalg::ArrayInterfaceStr(t_data); | ||
|
||
/** | ||
* Create configuration object. | ||
*/ | ||
xgboost::Json config{xgboost::Object{}}; | ||
config["cache_id"] = xgboost::Integer{}; | ||
config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)}; | ||
config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)}; | ||
config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)}; | ||
config["missing"] = xgboost::Number{static_cast<float>(missing)}; | ||
config["strict_shape"] = xgboost::Boolean{true}; | ||
std::string s_config; | ||
xgboost::Json::Dump(config, &s_config); | ||
|
||
/** | ||
* Handle base margin | ||
*/ | ||
BoosterHandle proxy{nullptr}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please help add a test for prediction with margin. With regression using
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @trivialfis Hello! I wrote a test, but it appears to be failing. I think it may be related to this: #9536 Is the in-place prediction not working properly with margins at the moment? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The bug happens when you are running prediction with base margin on GPU while the input is from the CPU or the other way around. I don't think it's relevant to this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Feel free to push your tests, I can take a look. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you very much! I pushed the new tests. Let me know if you have any questions on anything! Right now what I'm seeing with the failing test is that the prediction( x with margin) != prediction(x) + margin There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome! That did the trick. I'll post your changes from below shortly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great! Please merge/rebase the latest branch. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! Let me know if there's anything else needed! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Fix CI errors, if there's any. ;-) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the CI failed with some of the new C++ code you added from the sovrn#9 branch. I tried taking a look at it, but by no means am I a C++ expert. It looks like some sort of parameter type mismatch is occurring. Could you take a look? |
||
|
||
float *margin{nullptr}; | ||
if (jmargin) { | ||
margin = jenv->GetFloatArrayElements(jmargin, nullptr); | ||
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy)); | ||
JVM_CHECK_CALL( | ||
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin))); | ||
} | ||
|
||
bst_ulong const *out_shape; | ||
bst_ulong out_dim; | ||
float const *result; | ||
auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape, | ||
&out_dim, &result); | ||
|
||
jenv->ReleaseFloatArrayElements(jdata, data, 0); | ||
if (proxy) { | ||
XGDMatrixFree(proxy); | ||
jenv->ReleaseFloatArrayElements(jmargin, margin, 0); | ||
} | ||
|
||
if (ret != 0) { | ||
return ret; | ||
} | ||
|
||
std::size_t n{1}; | ||
for (std::size_t i = 0; i < out_dim; ++i) { | ||
n *= out_shape[i]; | ||
} | ||
|
||
jfloatArray jarray = jenv->NewFloatArray(n); | ||
|
||
jenv->SetFloatArrayRegion(jarray, 0, n, result); | ||
jenv->SetObjectArrayElement(jout, 0, jarray); | ||
|
||
API_END(); | ||
} | ||
|
||
/* | ||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI | ||
* Method: XGBoosterLoadModel | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wbo4958 I just learned about the existence of
BigDenseMatrix
, do you think we should return the prediction in that?