Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages] support operating attributes of booster #4336

Merged
merged 1 commit into from
Apr 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,60 @@ public void setParams(Map<String, Object> params) throws XGBoostError {
}
}

/**
* Get attributes stored in the Booster as a Map.
*
* @return A map contain attribute pairs.
* @throws XGBoostError native error
*/
public final Map<String, String> getAttrs() throws XGBoostError {
String[][] attrNames = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttrNames(handle, attrNames));
Map<String, String> attrMap = new HashMap<>();
for (String name: attrNames[0]) {
attrMap.put(name, this.getAttr(name));
}
return attrMap;
}

/**
* Get attribute from the Booster.
*
* @param key attribute key
* @return attribute value
* @throws XGBoostError native error
*/
public final String getAttr(String key) throws XGBoostError {
String[] attrValue = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttr(handle, key, attrValue));
return attrValue[0];
}

/**
* Set attribute to the Booster.
*
* @param key attribute key
* @param value attribute value
* @throws XGBoostError native error
*/
public final void setAttr(String key, String value) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetAttr(handle, key, value));
}

/**
* Set attributes to the Booster.
*
* @param attrs attributes key-value map
* @throws XGBoostError native error
*/
public void setAttrs(Map<String, String> attrs) throws XGBoostError {
if (attrs != null) {
for (Map.Entry<String, String> entry : attrs.entrySet()) {
setAttr(entry.getKey(), entry.getValue());
}
}
}

/**
* Update the booster for one iteration.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public final static native int XGBoosterDumpModelEx(long handle, String fmap, in
public final static native int XGBoosterDumpModelExWithFeatures(
long handle, String[] feature_names, int with_stats, String format, String[][] out_strings);

public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,48 @@ import scala.collection.mutable
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
extends Serializable with KryoSerializable {

/**
* Get attributes stored in the Booster as a Map.
*
* @return A map contain attribute pairs.
*/
@throws(classOf[XGBoostError])
def getAttrs: Map[String, String] = {
booster.getAttrs.asScala.toMap
}

/**
* Get attribute from the Booster.
*
* @param key attr name
* @return attr value
*/
@throws(classOf[XGBoostError])
def getAttr(key: String): String = {
booster.getAttr(key)
}

/**
* Set attribute to the Booster.
*
* @param key attr name
* @param value attr value
*/
@throws(classOf[XGBoostError])
def setAttr(key: String, value: String): Unit = {
booster.setAttr(key, value)
}

/**
* set attributes
*
* @param params attributes key-value map
*/
@throws(classOf[XGBoostError])
def setAttrs(params: Map[String, String]): Unit = {
booster.setAttrs(params.asJava)
}

/**
* Set parameter to the Booster.
*
Expand Down
62 changes: 62 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetAttrNames
* Signature: (I[[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNames
(JNIEnv *jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong len = 0;
char **result;
int ret = XGBoosterGetAttrNames(handle, &len, (const char ***) &result);

jsize jlen = (jsize) len;
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
for(int i=0 ; i<jlen; i++) {
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i]));
}
jenv->SetObjectArrayElement(jout, 0, jinfos);

return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetAttr
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* key = jenv->GetStringUTFChars(jkey, 0);
const char* result;
int success;
int ret = XGBoosterGetAttr(handle, key, &result, &success);
//release
if (key) jenv->ReleaseStringUTFChars(jkey, key);

if (success > 0) {
jstring jret = jenv->NewStringUTF(result);
jenv->SetObjectArrayElement(jout, 0, jret);
}

return ret;
};

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSetAttr
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jstring jvalue) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* key = jenv->GetStringUTFChars(jkey, 0);
const char* value = jenv->GetStringUTFChars(jvalue, 0);
int ret = XGBoosterSetAttr(handle, key, value);
//release
if (key) jenv->ReleaseStringUTFChars(jkey, key);
if (value) jenv->ReleaseStringUTFChars(jvalue, value);
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
Expand Down
8 changes: 8 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -617,4 +617,34 @@ public void testTrainFromExistingModel() throws XGBoostError, IOException {
TestCase.assertTrue(booster1error == booster2error);
TestCase.assertTrue(tempBoosterError > booster2error);
}

/**
* test set/get attributes to/from a booster
*
* @throws XGBoostError
*/
@Test
public void testSetAndGetAttrs() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

Booster booster = trainBooster(trainMat, testMat);
booster.setAttr("testKey1", "testValue1");
TestCase.assertEquals(booster.getAttr("testKey1"), "testValue1");
booster.setAttr("testKey1", "testValue2");
TestCase.assertEquals(booster.getAttr("testKey1"), "testValue2");

booster.setAttrs(new HashMap<String, String>(){{
put("aa", "AA");
put("bb", "BB");
put("cc", "CC");
}});

Map<String, String> attr = booster.getAttrs();
TestCase.assertEquals(attr.size(), 4);
TestCase.assertEquals(attr.get("testKey1"), "testValue2");
TestCase.assertEquals(attr.get("aa"), "AA");
TestCase.assertEquals(attr.get("bb"), "BB");
TestCase.assertEquals(attr.get("cc"), "CC");
}
}