diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index f6913ffc769f..0d89d6a230ac 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -116,6 +116,60 @@ public void setParams(Map params) throws XGBoostError { } } + /** + * Get attributes stored in the Booster as a Map. + * + * @return A map contain attribute pairs. + * @throws XGBoostError native error + */ + public final Map getAttrs() throws XGBoostError { + String[][] attrNames = new String[1][]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttrNames(handle, attrNames)); + Map attrMap = new HashMap<>(); + for (String name: attrNames[0]) { + attrMap.put(name, this.getAttr(name)); + } + return attrMap; + } + + /** + * Get attribute from the Booster. + * + * @param key attribute key + * @return attribute value + * @throws XGBoostError native error + */ + public final String getAttr(String key) throws XGBoostError { + String[] attrValue = new String[1]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttr(handle, key, attrValue)); + return attrValue[0]; + } + + /** + * Set attribute to the Booster. + * + * @param key attribute key + * @param value attribute value + * @throws XGBoostError native error + */ + public final void setAttr(String key, String value) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetAttr(handle, key, value)); + } + + /** + * Set attributes to the Booster. + * + * @param attrs attributes key-value map + * @throws XGBoostError native error + */ + public void setAttrs(Map attrs) throws XGBoostError { + if (attrs != null) { + for (Map.Entry entry : attrs.entrySet()) { + setAttr(entry.getKey(), entry.getValue()); + } + } + } + /** * Update the booster for one iteration. * diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index d20a27ba12cd..e797d67aa3a2 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -114,6 +114,7 @@ public final static native int XGBoosterDumpModelEx(long handle, String fmap, in public final static native int XGBoosterDumpModelExWithFeatures( long handle, String[] feature_names, int with_stats, String format, String[][] out_strings); + public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings); public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); public final static native int XGBoosterSetAttr(long handle, String key, String value); public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index f86ab8e18624..bb2d5e9e576c 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -32,6 +32,48 @@ import scala.collection.mutable class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) extends Serializable with KryoSerializable { + /** + * Get attributes stored in the Booster as a Map. + * + * @return A map contain attribute pairs. + */ + @throws(classOf[XGBoostError]) + def getAttrs: Map[String, String] = { + booster.getAttrs.asScala.toMap + } + + /** + * Get attribute from the Booster. + * + * @param key attr name + * @return attr value + */ + @throws(classOf[XGBoostError]) + def getAttr(key: String): String = { + booster.getAttr(key) + } + + /** + * Set attribute to the Booster. + * + * @param key attr name + * @param value attr value + */ + @throws(classOf[XGBoostError]) + def setAttr(key: String, value: String): Unit = { + booster.setAttr(key, value) + } + + /** + * set attributes + * + * @param params attributes key-value map + */ + @throws(classOf[XGBoostError]) + def setAttrs(params: Map[String, String]): Unit = { + booster.setAttrs(params.asJava) + } + /** * Set parameter to the Booster. * diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 02ab0b2289b3..5e268093df84 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -706,6 +706,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel return ret; } +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterGetAttrNames + * Signature: (I[[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNames + (JNIEnv *jenv, jclass jcls, jlong jhandle, jobjectArray jout) { + BoosterHandle handle = (BoosterHandle) jhandle; + bst_ulong len = 0; + char **result; + int ret = XGBoosterGetAttrNames(handle, &len, (const char ***) &result); + + jsize jlen = (jsize) len; + jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); + for(int i=0 ; iSetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i])); + } + jenv->SetObjectArrayElement(jout, 0, jinfos); + + return ret; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterGetAttr + * Signature: (JLjava/lang/String;[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jobjectArray jout) { + BoosterHandle handle = (BoosterHandle) jhandle; + const char* key = jenv->GetStringUTFChars(jkey, 0); + const char* result; + int success; + int ret = XGBoosterGetAttr(handle, key, &result, &success); + //release + if (key) jenv->ReleaseStringUTFChars(jkey, key); + + if (success > 0) { + jstring jret = jenv->NewStringUTF(result); + jenv->SetObjectArrayElement(jout, 0, jret); + } + + return ret; +}; + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterSetAttr + * Signature: (JLjava/lang/String;Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jstring jvalue) { + BoosterHandle handle = (BoosterHandle) jhandle; + const char* key = jenv->GetStringUTFChars(jkey, 0); + const char* value = jenv->GetStringUTFChars(jvalue, 0); + int ret = XGBoosterSetAttr(handle, key, value); + //release + if (key) jenv->ReleaseStringUTFChars(jkey, key); + if (value) jenv->ReleaseStringUTFChars(jvalue, value); + return ret; +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterLoadRabitCheckpoint diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 3994d7825076..96eaa97b27cc 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -231,6 +231,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures (JNIEnv *, jclass, jlong, jobjectArray, jint, jstring, jobjectArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterGetAttrNames + * Signature: (I[[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNames + (JNIEnv *, jclass, jlong, jobjectArray); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterGetAttr diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index c03174261839..a86144e66e6d 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -617,4 +617,34 @@ public void testTrainFromExistingModel() throws XGBoostError, IOException { TestCase.assertTrue(booster1error == booster2error); TestCase.assertTrue(tempBoosterError > booster2error); } + + /** + * test set/get attributes to/from a booster + * + * @throws XGBoostError + */ + @Test + public void testSetAndGetAttrs() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + Booster booster = trainBooster(trainMat, testMat); + booster.setAttr("testKey1", "testValue1"); + TestCase.assertEquals(booster.getAttr("testKey1"), "testValue1"); + booster.setAttr("testKey1", "testValue2"); + TestCase.assertEquals(booster.getAttr("testKey1"), "testValue2"); + + booster.setAttrs(new HashMap(){{ + put("aa", "AA"); + put("bb", "BB"); + put("cc", "CC"); + }}); + + Map attr = booster.getAttrs(); + TestCase.assertEquals(attr.size(), 4); + TestCase.assertEquals(attr.get("testKey1"), "testValue2"); + TestCase.assertEquals(attr.get("aa"), "AA"); + TestCase.assertEquals(attr.get("bb"), "BB"); + TestCase.assertEquals(attr.get("cc"), "CC"); + } }