Skip to content

Commit

Permalink
pass correct parameters to xgboost library
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jul 13, 2020
1 parent 99f4ac7 commit ab3412d
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,17 +348,18 @@ object XGBoost extends Serializable {
throw new RuntimeException("Something wrong for task context")
}
val resources = tc.resources()

if (resources.contains("gpu")) {
val addrs = resources("gpu").addresses
if (addrs.size > 1) {
// TODO should we throw exception ?
logger.warn("XGBoost only supports 1 gpu per worker")
}
logger.warn("===================================xxxxxxxxxxxx")
// take the first one
addrs.head.toInt
} else {
throw new RuntimeException("gpu is not allocated by spark, " +
"pls check if gpu scheduling is enabled")
"please check if gpu scheduling is enabled")
}
}

Expand Down Expand Up @@ -398,15 +399,17 @@ object XGBoost extends Serializable {
} else {
getGPUAddrFromResources
}
logger.info("Leveraging gpu device " + gpuId + " to train")
params = params + ("gpu_id" -> gpuId)
}
logger.info(params)
val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap("train"), params, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
} else {
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
SXGBoost.train(watches.toMap("train"), params, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
}
Expand Down

0 comments on commit ab3412d

Please sign in to comment.