Improved multi-node multi-GPU random forests. #4238

Merged: 8 commits, Mar 12, 2019
21 changes: 16 additions & 5 deletions src/common/random.h
@@ -89,9 +89,10 @@ class ColumnSampler {
float colsample_bylevel_{1.0f};
float colsample_bytree_{1.0f};
float colsample_bynode_{1.0f};
GlobalRandomEngine rng_;

std::shared_ptr<std::vector<int>> ColSample
(std::shared_ptr<std::vector<int>> p_features, float colsample) const {
(std::shared_ptr<std::vector<int>> p_features, float colsample) {
if (colsample == 1.0f) return p_features;
const auto& features = *p_features;
CHECK_GT(features.size(), 0);
@@ -100,17 +101,24 @@ class ColumnSampler {
auto& new_features = *p_new_features;
new_features.resize(features.size());
std::copy(features.begin(), features.end(), new_features.begin());
std::shuffle(new_features.begin(), new_features.end(), common::GlobalRandom());
std::shuffle(new_features.begin(), new_features.end(), rng_);
new_features.resize(n);
std::sort(new_features.begin(), new_features.end());

// ensure that new_features are the same across ranks
rabit::Broadcast(&new_features, 0);

return p_new_features;
}

public:
/**
* \brief Column sampler constructor.
* \note This constructor synchronizes the RNG seed across processes.
*/
ColumnSampler() {
uint32_t seed = common::GlobalRandom()();
rabit::Broadcast(&seed, sizeof(seed), 0);
rng_.seed(seed);
}

/**
* \brief Initialise this object before use.
*
@@ -153,6 +161,9 @@ class ColumnSampler {
* \return The sampled feature set.
* \note If colsample_bynode_ < 1.0, this method creates a new feature set each time it
* is called. Therefore, it should be called only once per node.
* \note With distributed xgboost, this function must be called exactly once for the
* construction of each tree node, and must be called the same number of times in each
* process and with the same parameters to return the same feature set across processes.
*/
std::shared_ptr<std::vector<int>> GetFeatureSet(int depth) {
if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
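Taken together, the random.h changes replace the per-call rabit::Broadcast of the sampled feature list with a one-time broadcast of the RNG seed: the ColumnSampler constructor seeds rng_ from rank 0's seed, so every process produces the same shuffle sequence and therefore the same column samples, provided GetFeatureSet is called the same number of times with the same arguments on every rank. Below is a minimal Python sketch of the idea, with a hypothetical broadcast(value, root) stand-in for the rabit primitive (not the actual xgboost API).

import random

def broadcast(value, root=0):
    # Hypothetical stand-in for rabit's broadcast collective: in a real
    # cluster every rank would return the root rank's value; here there is
    # only one process, so the value is returned unchanged.
    return value

class ColumnSamplerSketch:
    """Conceptual analogue of the ColumnSampler change above."""

    def __init__(self):
        # Seed the per-sampler RNG once from a value agreed on by all ranks.
        seed = broadcast(random.randrange(2 ** 32), root=0)
        self.rng = random.Random(seed)

    def col_sample(self, features, colsample):
        if colsample == 1.0:
            return list(features)
        n = max(1, int(colsample * len(features)))
        shuffled = list(features)
        # Same seed + same call sequence => same shuffle on every rank,
        # so no per-call broadcast of the sampled feature set is needed.
        self.rng.shuffle(shuffled)
        return sorted(shuffled[:n])

sampler = ColumnSamplerSketch()
print(sampler.col_sample(range(10), 0.5))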
48 changes: 48 additions & 0 deletions tests/distributed-gpu/distributed_gpu.py
@@ -0,0 +1,48 @@
"""Common functions for distributed GPU tests."""
import time
import xgboost as xgb


def run_test(name, params_fun, num_round):
"""Runs a distributed GPU test."""
# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load the data files; they will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

params = params_fun(rank)

# Specify validation sets to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]

# Run training; all the features of the training API are available.
# Currently, this script only supports calling train once, for fault-recovery purposes.
bst = xgb.train(params, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.%s.%d" % (name, rank)
bst.dump_model(model_name, with_stats=True)
time.sleep(2)
xgb.rabit.tracker_print("Finished training\n")

if (rank == 0):
for i in range(0, world):
model_name_root = "test.model.%s.%d" % (name, i)
for j in range(0, world):
if i == j:
continue
with open(model_name_root, 'r') as model_root:
contents_root = model_root.read()
model_name_rank = "test.model.%s.%d" % (name, j)
with open(model_name_rank, 'r') as model_rank:
contents_rank = model_rank.read()
if contents_root != contents_rank:
raise Exception(
('Worker models diverged: test.model.%s.%d '
'differs from test.model.%s.%d') % (name, i, name, j))

xgb.rabit.finalize()
18 changes: 17 additions & 1 deletion tests/distributed-gpu/runtests-gpu.sh
@@ -16,4 +16,20 @@ PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --c

echo -e "\n ====== 4. Basic distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=1 \
python test_gpu_basic_4x1.py
python test_gpu_basic_4x1.py

echo -e "\n ====== 5. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=4 \
python test_gpu_rf_1x4.py

echo -e "\n ====== 6. RF distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
python test_gpu_rf_2x2.py

echo -e "\n ====== 7. RF distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
python test_gpu_rf_asym.py

echo -e "\n ====== 8. RF distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=1 \
python test_gpu_rf_4x1.py
64 changes: 14 additions & 50 deletions tests/distributed-gpu/test_gpu_basic_1x4.py
@@ -1,51 +1,15 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map, definition are same as c++ version
param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.1x4." + str(rank)
bst.dump_model(model_name, with_stats=True); time.sleep(2)
xgb.rabit.tracker_print("Finished training\n")

fail = False
if (rank == 0):
for i in range(0, world):
model_name_root = "test.model.1x4." + str(i)
for j in range(0, world):
if i != j:
with open(model_name_root, 'r') as model_root:
model_name_rank = "test.model.1x4." + str(j)
with open(model_name_rank, 'r') as model_rank:
diff = set(model_root).difference(model_rank)
if len(diff) != 0:
fail = True
xgb.rabit.finalize()
raise Exception('Worker models diverged: test.model.1x4.{} differs from test.model.1x4.{}'.format(i, j))

if (rank != 0) and (fail):
xgb.rabit.finalize()

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
import distributed_gpu as dgpu
Member: I think these python files can be combined into a single file using command line arguments to configure the test.

Contributor Author: Done.


def params_fun(rank):
return {
'n_gpus': 1,
'gpu_id': rank,
'tree_method': 'gpu_hist',
'max_depth': 2,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic'
}

dgpu.run_test('1x4', params_fun, 20)
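Per the review suggestion above, the per-configuration scripts could be collapsed further into a single driver whose configuration comes from the command line. The sketch below is hypothetical: the file name, flag names, and defaults are illustrative only and not part of this PR.

#!/usr/bin/python
"""Hypothetical consolidated test driver (not part of this PR)."""
import argparse
import distributed_gpu as dgpu

def main():
    parser = argparse.ArgumentParser(description='Distributed GPU test driver')
    parser.add_argument('--name', required=True,
                        help='test name used in the dumped model file names')
    parser.add_argument('--n-gpus', type=int, default=1,
                        help='GPUs used by each worker')
    parser.add_argument('--gpus-per-rank', type=int, default=1,
                        help='gpu_id is derived as rank * gpus_per_rank')
    parser.add_argument('--num-round', type=int, default=20)
    parser.add_argument('--random-forest', action='store_true',
                        help='add subsample/colsample_bynode/num_parallel_tree')
    args = parser.parse_args()

    def params_fun(rank):
        params = {
            'n_gpus': args.n_gpus,
            'gpu_id': rank * args.gpus_per_rank,
            'tree_method': 'gpu_hist',
            'max_depth': 2,
            'eta': 1,
            'silent': 1,
            'objective': 'binary:logistic',
        }
        if args.random_forest:
            params.update({'subsample': 0.5,
                           'colsample_bynode': 0.5,
                           'num_parallel_tree': 20})
        return params

    dgpu.run_test(args.name, params_fun, args.num_round)

if __name__ == '__main__':
    main()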
64 changes: 14 additions & 50 deletions tests/distributed-gpu/test_gpu_basic_2x2.py
@@ -1,51 +1,15 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map, definition are same as c++ version
param = {'n_gpus': 2, 'gpu_id': 2*rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.2x2." + str(rank)
bst.dump_model(model_name, with_stats=True); time.sleep(2)
xgb.rabit.tracker_print("Finished training\n")

fail = False
if (rank == 0):
for i in range(0, world):
model_name_root = "test.model.2x2." + str(i)
for j in range(0, world):
if i != j:
with open(model_name_root, 'r') as model_root:
model_name_rank = "test.model.2x2." + str(j)
with open(model_name_rank, 'r') as model_rank:
diff = set(model_root).difference(model_rank)
if len(diff) != 0:
fail = True
xgb.rabit.finalize()
raise Exception('Worker models diverged: test.model.2x2.{} differs from test.model.2x2.{}'.format(i, j))

if (rank != 0) and (fail):
xgb.rabit.finalize()

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
import distributed_gpu as dgpu

def params_fun(rank):
return {
'n_gpus': 2,
'gpu_id': 2*rank,
'tree_method': 'gpu_hist',
'max_depth': 2,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic'
}

dgpu.run_test('2x2', params_fun, 20)
47 changes: 14 additions & 33 deletions tests/distributed-gpu/test_gpu_basic_4x1.py
@@ -1,34 +1,15 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map, definition are same as c++ version
param = {'n_gpus': 4, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have root save its model
if(rank == 0):
model_name = "test.model.4x1." + str(rank)
bst.dump_model(model_name, with_stats=True)
xgb.rabit.tracker_print("Finished training\n")

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
import distributed_gpu as dgpu

def params_fun(rank):
return {
'n_gpus': 4,
'gpu_id': rank,
'tree_method': 'gpu_hist',
'max_depth': 2,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic'
}

dgpu.run_test('4x1', params_fun, 20)
67 changes: 14 additions & 53 deletions tests/distributed-gpu/test_gpu_basic_asym.py
@@ -1,54 +1,15 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map, definition are same as c++ version
if rank == 0:
param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
else:
param = {'n_gpus': 3, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.asym." + str(rank)
bst.dump_model(model_name, with_stats=True); time.sleep(2)
xgb.rabit.tracker_print("Finished training\n")

fail = False
if (rank == 0):
for i in range(0, world):
model_name_root = "test.model.asym." + str(i)
for j in range(0, world):
if i != j:
with open(model_name_root, 'r') as model_root:
model_name_rank = "test.model.asym." + str(j)
with open(model_name_rank, 'r') as model_rank:
diff = set(model_root).difference(model_rank)
if len(diff) != 0:
fail = True
xgb.rabit.finalize()
raise Exception('Worker models diverged: test.model.asym.{} differs from test.model.asym.{}'.format(i, j))

if (rank != 0) and (fail):
xgb.rabit.finalize()

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
import distributed_gpu as dgpu

def params_fun(rank):
return {
'gpu_id': rank,
'tree_method': 'gpu_hist',
'max_depth': 2,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'n_gpus': 1 if rank == 0 else 3
}

dgpu.run_test('asym', params_fun, 20)
18 changes: 18 additions & 0 deletions tests/distributed-gpu/test_gpu_rf_1x4.py
@@ -0,0 +1,18 @@
#!/usr/bin/python
import distributed_gpu as dgpu

def params_fun(rank):
return {
'n_gpus': 1,
'gpu_id': rank,
'tree_method': 'gpu_hist',
'max_depth': 2,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'subsample': 0.5,
'colsample_bynode': 0.5,
'num_parallel_tree': 20
}

dgpu.run_test('rf.1x4', params_fun, 1)
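These parameters switch xgboost into random-forest mode: num_parallel_tree grows 20 trees within a single boosting round (run_test is invoked with num_round = 1), while subsample and colsample_bynode randomize the rows and the per-node feature sets of each tree, which is exactly what exercises the synchronized column sampler above. The following is a minimal single-process sketch of the same configuration, assuming the same demo data paths and an available GPU (swap in tree_method='hist' on a CPU-only machine).

import xgboost as xgb

# Single-process sketch of the random-forest configuration used by the
# distributed tests (paths and parameter values copied from above).
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

params = {
    'tree_method': 'gpu_hist',   # use 'hist' if no GPU is available
    'max_depth': 2,
    'eta': 1,
    'objective': 'binary:logistic',
    'subsample': 0.5,            # row sampling per tree
    'colsample_bynode': 0.5,     # column sampling per tree node
    'num_parallel_tree': 20,     # grow 20 trees per boosting round
}

# One boosting round with 20 parallel trees yields a 20-tree random forest
# rather than a boosted ensemble.
bst = xgb.train(params, dtrain, num_boost_round=1,
                evals=[(dtest, 'eval'), (dtrain, 'train')])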