[dask] Honor nthreads from dask worker. (#5414)
trivialfis committed Mar 15, 2020
1 parent 21b671a commit 761a5db
Showing 5 changed files with 59 additions and 16 deletions.
3 changes: 1 addition & 2 deletions demo/dask/cpu_training.py
@@ -22,7 +22,6 @@ def main(client):
# evaluation metrics.
output = xgb.dask.train(client,
{'verbosity': 1,
'nthread': 1,
'tree_method': 'hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
@@ -37,6 +36,6 @@ def main(client):

if __name__ == '__main__':
# or use other clusters for scaling
with LocalCluster(n_workers=7, threads_per_worker=1) as cluster:
with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
with Client(cluster) as client:
main(client)
3 changes: 1 addition & 2 deletions demo/dask/gpu_training.py
@@ -22,7 +22,6 @@ def main(client):
# evaluation metrics.
output = xgb.dask.train(client,
{'verbosity': 2,
'nthread': 1,
# Golden line for GPU training
'tree_method': 'gpu_hist'},
dtrain,
@@ -41,6 +40,6 @@ def main(client):
# `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
# `n_workers` represents the number of GPUs since we use one GPU per worker
# process.
with LocalCUDACluster(n_workers=2, threads_per_worker=1) as cluster:
with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
with Client(cluster) as client:
main(client)
27 changes: 26 additions & 1 deletion doc/tutorials/dask.rst
@@ -37,7 +37,6 @@ illustrates the basic usage:
output = xgb.dask.train(client,
{'verbosity': 2,
'nthread': 1,
'tree_method': 'hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
@@ -76,6 +75,32 @@
Another set of APIs is the Scikit-Learn wrapper, which mimics the stateful Scikit-Learn
interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask``
for more examples.
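
For reference, a minimal sketch of the wrapper interface might look like the following
(assuming the ``client``, ``X`` and ``y`` from the earlier example; the parameter values
are illustrative only):

.. code-block:: python

  regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
  regressor.set_params(tree_method='hist')
  # Assign the Dask client before calling fit/predict.
  regressor.client = client
  regressor.fit(X, y)
  prediction = regressor.predict(X)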

*******
Threads
*******

XGBoost has built-in support for parallel computation through threads, controlled by the
``nthread`` parameter (``n_jobs`` for the scikit-learn interface). If these parameters are
set, they will override the thread configuration in Dask. For example:

.. code-block:: python

  with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:

Here 4 threads are allocated to each Dask worker, so by default XGBoost will use 4 threads
in each worker process for both training and prediction. But if the ``nthread`` parameter
is set:

.. code-block:: python

  output = xgb.dask.train(client,
                          {'verbosity': 1,
                           'nthread': 8,
                           'tree_method': 'hist'},
                          dtrain,
                          num_boost_round=4, evals=[(dtrain, 'train')])

XGBoost will use 8 threads in each training process.
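
To check which thread count actually took effect, the trained booster's saved
configuration can be inspected (a small sketch based on the test added in this change;
``output`` is the return value of ``xgb.dask.train`` above):

.. code-block:: python

  import json

  booster = output['booster']
  config = json.loads(booster.save_config())
  # `nthread` here reflects the value used by each training process.
  print(config['learner']['generic_param']['nthread'])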

*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
33 changes: 24 additions & 9 deletions python-package/xgboost/dask.py
@@ -42,6 +42,9 @@
# - Ranking


LOGGER = logging.getLogger('[xgboost.dask]')


def _start_tracker(host, n_workers):
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
@@ -62,7 +65,7 @@ def _assert_dask_support():
if platform.system() == 'Windows':
msg = 'Windows is not officially supported for dask/xgboost,'
msg += ' contributions are welcome.'
logging.warning(msg)
LOGGER.warning(msg)


class RabitContext:
@@ -75,11 +78,11 @@ def __init__(self, args):

def __enter__(self):
rabit.init(self.args)
logging.debug('-------------- rabit say hello ------------------')
LOGGER.debug('-------------- rabit say hello ------------------')

def __exit__(self, *args):
rabit.finalize()
logging.debug('--------------- rabit say bye ------------------')
LOGGER.debug('--------------- rabit say bye ------------------')


def concat(value):
@@ -301,7 +304,7 @@ def get_worker_data(self, worker):
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(self.worker_map.keys()))
logging.warning(msg)
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=self.feature_names,
feature_types=self.feature_types)
@@ -324,7 +327,8 @@ def get_worker_data(self, worker):
weight=weights,
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types)
feature_types=self.feature_types,
nthread=worker.nthreads)
return dmatrix

def get_worker_data_shape(self, worker):
@@ -399,7 +403,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):

def dispatched_train(worker_addr):
'''Perform training on a single worker.'''
logging.info('Training on %s', str(worker_addr))
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed_get_worker()
with RabitContext(rabit_args):
local_dtrain = dtrain.get_worker_data(worker)
@@ -415,6 +419,15 @@ def dispatched_train(worker_addr):

local_history = {}
local_param = params.copy() # just to be consistent
msg = 'Overriding `nthreads` defined in dask worker.'
if 'nthread' in local_param.keys():
    msg = '`nthread` is specified. ' + msg
    LOGGER.warning(msg)
elif 'n_jobs' in local_param.keys():
    msg = '`n_jobs` is specified. ' + msg
    LOGGER.warning(msg)
else:
    # Neither is set, so fall back to the thread count of the dask worker.
    local_param['nthread'] = worker.nthreads
bst = worker_train(params=local_param,
dtrain=local_dtrain,
*args,
@@ -477,15 +490,17 @@ def predict(client, model, data, *args):

def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
logging.info('Predicting on %d', worker_id)
LOGGER.info('Predicting on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = data.get_worker_x_ordered(worker)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts:
local_x = DMatrix(part,
feature_names=feature_names,
feature_types=feature_types,
missing=missing)
missing=missing,
nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
@@ -495,7 +510,7 @@ def dispatched_get_shape(worker_id):

def dispatched_get_shape(worker_id):
'''Get shape of data in each worker.'''
logging.info('Trying to get data shape on %d', worker_id)
LOGGER.info('Trying to get data shape on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = data.get_worker_x_ordered(worker)
shapes = []
9 changes: 7 additions & 2 deletions tests/python/test_with_dask.py
@@ -3,6 +3,7 @@
import xgboost as xgb
import sys
import numpy as np
import json

if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -60,7 +61,7 @@ def test_from_dask_dataframe():


def test_from_dask_array():
with LocalCluster(n_workers=5) as cluster:
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
with Client(cluster) as client:
X, y = generate_array()
dtrain = DaskDMatrix(client, X, y)
@@ -74,11 +75,15 @@ def test_from_dask_array():
# force prediction to be computed
prediction = prediction.compute()

single_node_predt = result['booster'].predict(
booster = result['booster']
single_node_predt = booster.predict(
xgb.DMatrix(X.compute())
)
np.testing.assert_allclose(prediction, single_node_predt)

config = json.loads(booster.save_config())
assert int(config['learner']['generic_param']['nthread']) == 5


def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster:
