RMM integration plugin #5873

Merged
merged 48 commits on Aug 12, 2020
Changes from 45 commits

Commits (48)
b7a322d
[CI] Add RMM as an optional dependency
hcho3 Jul 8, 2020
e15845d
Replace caching allocator with pool allocator from RMM
hcho3 Jul 8, 2020
812c209
Revert "Replace caching allocator with pool allocator from RMM"
hcho3 Jul 9, 2020
a891112
Use rmm::mr::get_default_resource()
hcho3 Jul 9, 2020
b5eb54d
Try setting default resource (doesn't work yet)
hcho3 Jul 9, 2020
6abd4c0
Allocate pool_mr in the heap
hcho3 Jul 9, 2020
2bdbc23
Prevent leaking pool_mr handle
hcho3 Jul 9, 2020
c723632
Separate EXPECT_DEATH() in separate test suite suffixed DeathTest
hcho3 Jul 9, 2020
78c2254
Turn off death tests for RMM
hcho3 Jul 9, 2020
a520fa1
Address reviewer's feedback
hcho3 Jul 9, 2020
a73391c
Prevent leaking of cuda_mr
hcho3 Jul 10, 2020
309efc0
Merge remote-tracking branch 'origin/master' into add_rmm
hcho3 Jul 22, 2020
fa4ec11
Fix Jenkinsfile syntax
hcho3 Jul 22, 2020
871fc29
Remove unnecessary function in Jenkinsfile
hcho3 Jul 22, 2020
48051df
[CI] Install NCCL into RMM container
hcho3 Jul 22, 2020
c0a05ce
Run Python tests
hcho3 Jul 22, 2020
c12e0a6
Try building with RMM, CUDA 10.0
hcho3 Jul 22, 2020
a3e0e2f
Do not use RMM for CUDA 10.0 target
hcho3 Jul 22, 2020
3aeab69
Actually test for test_rmm flag
hcho3 Jul 22, 2020
862d580
Fix TestPythonGPU
hcho3 Jul 22, 2020
2a064bf
Use CNMeM allocator, since pool allocator doesn't yet support multiGPU
hcho3 Jul 29, 2020
ab4e7b4
Merge branch 'master' into add_rmm
hcho3 Jul 29, 2020
dd05d7b
Merge remote-tracking branch 'origin/master' into add_rmm
hcho3 Jul 29, 2020
a4da8c5
Merge remote-tracking branch 'upstream/master' into add_rmm
hcho3 Jul 29, 2020
789021f
Use 10.0 container to build RMM-enabled XGBoost
hcho3 Jul 30, 2020
f27d836
Revert "Use 10.0 container to build RMM-enabled XGBoost"
hcho3 Jul 31, 2020
a4b86a9
Fix Jenkinsfile
hcho3 Jul 31, 2020
e5eb262
[CI] Assign larger /dev/shm to NCCL
hcho3 Jul 31, 2020
4cf7f00
Use 10.2 artifact to run multi-GPU Python tests
hcho3 Jul 31, 2020
d023a50
Add CUDA 10.0 -> 11.0 cross-version test; remove CUDA 10.0 target
hcho3 Jul 31, 2020
abc64a3
Rename Conda env rmm_test -> gpu_test
hcho3 Jul 31, 2020
1e7e42e
Use env var to opt into CNMeM pool for C++ tests
hcho3 Jul 31, 2020
f1eeaff
Merge branch 'master' into add_rmm
hcho3 Jul 31, 2020
1069ae0
Use identical CUDA version for RMM builds and tests
hcho3 Jul 31, 2020
99a7520
Use Pytest fixtures to enable RMM pool in Python tests
hcho3 Aug 6, 2020
ecc16ec
Merge remote-tracking branch 'upstream/master' into add_rmm
hcho3 Aug 7, 2020
92d1481
Move RMM to plugin/CMakeLists.txt; use PLUGIN_RMM
hcho3 Aug 7, 2020
e74fd0d
Use per-device MR; use command arg in gtest
hcho3 Aug 8, 2020
2ee04b3
Set CMake prefix path to use Conda env
hcho3 Aug 8, 2020
87422a2
Use 0.15 nightly version of RMM
hcho3 Aug 8, 2020
9021a75
Remove unnecessary header
hcho3 Aug 8, 2020
377580a
Fix a unit test when cudf is missing
Aug 8, 2020
2f3c532
Merge remote-tracking branch 'upstream/master' into add_rmm
hcho3 Aug 9, 2020
3df7cc3
Add RMM demos
hcho3 Aug 10, 2020
567fb33
Remove print()
hcho3 Aug 10, 2020
1e63c46
Use HostDeviceVector in GPU predictor
hcho3 Aug 11, 2020
ad216c5
Simplify pytest setup; use LocalCUDACluster fixture
hcho3 Aug 11, 2020
b4195cd
Address reviewers' commments
hcho3 Aug 11, 2020
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -60,6 +60,7 @@ address, leak, undefined and thread.")
## Plugins
option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
## TODO: 1. Add check if DPC++ compiler is used for building
option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
@@ -84,6 +85,9 @@ endif (R_LIB AND GOOGLE_TEST)
if (USE_AVX)
message(SEND_ERROR "The option 'USE_AVX' is deprecated as experimental AVX features have been removed from XGBoost.")
endif (USE_AVX)
if (PLUGIN_RMM AND NOT (USE_CUDA))
message(SEND_ERROR "`PLUGIN_RMM` must be enabled with `USE_CUDA` flag.")
endif (PLUGIN_RMM AND NOT (USE_CUDA))
if (ENABLE_ALL_WARNINGS)
if ((NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang") AND (NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
message(SEND_ERROR "ENABLE_ALL_WARNINGS is only available for Clang and GCC.")
57 changes: 41 additions & 16 deletions Jenkinsfile
@@ -73,7 +73,7 @@ pipeline {
'build-gpu-cuda10.0': { BuildCUDA(cuda_version: '10.0') },
// The build-gpu-* builds below use Ubuntu image
'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') },
'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') },
'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2', build_rmm: true) },
'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') },
'build-jvm-packages-gpu-cuda10.0': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.0') },
'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
@@ -89,11 +89,12 @@
script {
parallel ([
'test-python-cpu': { TestPythonCPU() },
'test-python-gpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2') },
// artifact_cuda_version doesn't apply to RMM tests; RMM tests will always match CUDA version between artifact and host env
'test-python-gpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', test_rmm: true) },
'test-python-gpu-cuda11.0-cross': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '11.0') },
'test-python-gpu-cuda11.0': { TestPythonGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', multi_gpu: true) },
'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') },
'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', multi_gpu: true, test_rmm: true) },
'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', test_rmm: true) },
'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
'test-jvm-jdk8-cuda10.0': { CrossTestJVMwithJDKGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.0') },
'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') },
@@ -280,6 +281,22 @@ def BuildCUDA(args) {
}
echo 'Stashing C++ test executable (testxgboost)...'
stash name: "xgboost_cpp_tests_cuda${args.cuda_version}", includes: 'build/testxgboost'
if (args.build_rmm) {
echo "Build with CUDA ${args.cuda_version} and RMM"
container_type = "rmm"
docker_binary = "docker"
docker_args = "--build-arg CUDA_VERSION=${args.cuda_version}"
sh """
rm -rf build/
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_via_cmake.sh --conda-env=gpu_test -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON ${arch_flag}
${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "cd python-package && rm -rf dist/* && python setup.py bdist_wheel --universal"
${dockerRun} ${container_type} ${docker_binary} ${docker_args} python tests/ci_build/rename_whl.py python-package/dist/*.whl ${commit_id} manylinux2010_x86_64
"""
echo 'Stashing Python wheel...'
stash name: "xgboost_whl_rmm_cuda${args.cuda_version}", includes: 'python-package/dist/*.whl'
echo 'Stashing C++ test executable (testxgboost)...'
stash name: "xgboost_cpp_tests_rmm_cuda${args.cuda_version}", includes: 'build/testxgboost'
}
deleteDir()
}
}
@@ -366,18 +383,15 @@ def TestPythonGPU(args) {
def container_type = "gpu"
def docker_binary = "nvidia-docker"
def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
if (args.multi_gpu) {
echo "Using multiple GPUs"
// Allocate extra space in /dev/shm to enable NCCL
def docker_extra_params = "CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'"
sh """
${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu
"""
} else {
echo "Using a single GPU"
sh """
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh gpu
"""
def mgpu_indicator = (args.multi_gpu) ? 'mgpu' : 'gpu'
// Allocate extra space in /dev/shm to enable NCCL
def docker_extra_params = (args.multi_gpu) ? "CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'" : ''
sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator}"
if (args.test_rmm) {
sh "rm -rfv build/ python-package/dist/"
unstash name: "xgboost_whl_rmm_cuda${args.host_cuda_version}"
unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}"
sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator} --use-rmm-pool"
}
deleteDir()
}
@@ -408,6 +422,17 @@ def TestCppGPU(args) {
def docker_binary = "nvidia-docker"
def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
sh "${dockerRun} ${container_type} ${docker_binary} ${docker_args} build/testxgboost"
if (args.test_rmm) {
sh "rm -rfv build/"
unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}"
echo "Test C++, CUDA ${args.host_cuda_version} with RMM"
container_type = "rmm"
docker_binary = "nvidia-docker"
docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
sh """
${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "source activate gpu_test && build/testxgboost --use-rmm-pool --gtest_filter=-*DeathTest.*"
"""
}
deleteDir()
}
}
31 changes: 31 additions & 0 deletions demo/rmm_plugin/README.md
@@ -0,0 +1,31 @@
Using XGBoost with RAPIDS Memory Manager (RMM) plugin (EXPERIMENTAL)
====================================================================
The [RAPIDS Memory Manager (RMM)](https://github.com/rapidsai/rmm) library provides a collection of
efficient memory allocators for NVIDIA GPUs. XGBoost can now use these allocators by enabling the
RMM integration plugin.

The demos in this directory highlight one RMM allocator in particular: **the pool sub-allocator**.
This allocator addresses the high cost of repeated `cudaMalloc()` calls by reserving a large chunk
of memory upfront. Subsequent allocations draw from this pre-allocated pool and thus avoid the
overhead of calling `cudaMalloc()` directly. See
[these GTC talk slides](https://on-demand.gputechconf.com/gtc/2015/presentation/S5530-Stephen-Jones.pdf)
for more details.
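
Under the hood, enabling the pool amounts to registering a pool memory resource as RMM's current
device resource, which XGBoost's device allocators then consult. The sketch below illustrates this
wiring with the RMM 0.15-era C++ API that this plugin targets; it is a minimal illustration rather
than XGBoost code, and it simply relies on RMM's default initial pool size:
```
// Minimal sketch (RMM 0.15-era API): route device allocations through a pool.
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

int main() {
  // Upstream resource backed by plain cudaMalloc()/cudaFree().
  rmm::mr::cuda_memory_resource cuda_mr;
  // Pool sub-allocator: grabs a large chunk up front and serves later requests
  // from it, avoiding a cudaMalloc() call per allocation.
  rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{&cuda_mr};
  // Register it; with -DPLUGIN_RMM=ON, XGBoost's device allocators obtain memory
  // through rmm::mr::get_current_device_resource() and thus draw from this pool.
  rmm::mr::set_current_device_resource(&pool_mr);
  // ... train XGBoost models here ...
  return 0;
}
```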

Before running the demos, ensure that XGBoost is compiled with the RMM plugin enabled. To do this,
run CMake with the option `-DPLUGIN_RMM=ON` (`-DUSE_CUDA=ON` is also required):
```
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON
make -j4
```
CMake will attempt to locate the RMM library in your build environment. You may either build RMM
from source or install it with the Conda package manager. If CMake cannot find RMM, specify its
location via `CMAKE_PREFIX_PATH`:
```
# If using Conda:
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
# If RMM is installed at a custom location
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=/path/to/rmm
```

* [Using RMM with a single GPU](./rmm_singlegpu.py)
* [Using RMM with a local Dask cluster consisting of multiple GPUs](./rmm_mgpu_with_dask.py)
27 changes: 27 additions & 0 deletions demo/rmm_plugin/rmm_mgpu_with_dask.py
@@ -0,0 +1,27 @@
import xgboost as xgb
from sklearn.datasets import make_classification
import dask.array
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

def main(client):
    X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
    X = dask.array.from_array(X)
    y = dask.array.from_array(y)
    dtrain = xgb.dask.DaskDMatrix(client, X, label=y)

    params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
              'tree_method': 'gpu_hist'}
    output = xgb.dask.train(client, params, dtrain, num_boost_round=100,
                            evals=[(dtrain, 'train')])
    bst = output['booster']
    history = output['history']
    for i, e in enumerate(history['train']['merror']):
        print(f'[{i}] train-merror: {e}')

if __name__ == '__main__':
    # To use the RMM pool allocator with a GPU Dask cluster, just add the rmm_pool_size
    # option to the LocalCUDACluster constructor.
    with LocalCUDACluster(rmm_pool_size='2GB') as cluster:
        with Client(cluster) as client:
            main(client)
14 changes: 14 additions & 0 deletions demo/rmm_plugin/rmm_singlegpu.py
@@ -0,0 +1,14 @@
import xgboost as xgb
import rmm
from sklearn.datasets import make_classification

# Initialize RMM pool allocator
rmm.reinitialize(pool_allocator=True)

X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
dtrain = xgb.DMatrix(X, label=y)

params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
'tree_method': 'gpu_hist'}
# XGBoost will automatically use the RMM pool allocator
bst = xgb.train(params, dtrain, num_boost_round=100, evals=[(dtrain, 'train')])
19 changes: 19 additions & 0 deletions plugin/CMakeLists.txt
@@ -7,6 +7,25 @@ if (PLUGIN_DENSE_PARSER)
target_sources(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/dense_parser/dense_libsvm.cc)
endif (PLUGIN_DENSE_PARSER)

if (PLUGIN_RMM)
find_path(RMM_INCLUDE "rmm"
HINTS "$ENV{RMM_ROOT}/include")

find_library(RMM_LIBRARY "rmm"
HINTS "$ENV{RMM_ROOT}/lib" "$ENV{RMM_ROOT}/build")

if ((NOT RMM_LIBRARY) OR (NOT RMM_INCLUDE))
message(FATAL_ERROR "Could not locate RMM library")
endif ()

message(STATUS "RMM: RMM_LIBRARY set to ${RMM_LIBRARY}")
message(STATUS "RMM: RMM_INCLUDE set to ${RMM_INCLUDE}")

target_include_directories(objxgboost PUBLIC ${RMM_INCLUDE})
target_link_libraries(objxgboost PUBLIC ${RMM_LIBRARY} cuda)
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_RMM=1)
endif (PLUGIN_RMM)

if (PLUGIN_UPDATER_ONEAPI)
add_library(oneapi_plugin OBJECT
${xgboost_SOURCE_DIR}/plugin/updater_oneapi/regression_obj_oneapi.cc
2 changes: 1 addition & 1 deletion python-package/xgboost/data.py
@@ -317,7 +317,7 @@ def _is_cudf_df(data):
import cudf
except ImportError:
return False
return isinstance(data, cudf.DataFrame)
return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)


def _cudf_array_interfaces(data):
36 changes: 32 additions & 4 deletions src/common/device_helpers.cuh
@@ -36,7 +36,12 @@

#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
#endif
#endif // XGBOOST_USE_NCCL

#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
#include "rmm/mr/device/per_device_resource.hpp"
#include "rmm/mr/device/thrust_allocator_adaptor.hpp"
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)

@@ -370,12 +375,21 @@ inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
}

namespace detail {

#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
template <typename T>
using XGBBaseDeviceAllocator = rmm::mr::thrust_allocator<T>;
#else // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
template <typename T>
using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1

/**
* \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
*/
template <class T>
struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
using SuperT = thrust::device_malloc_allocator<T>;
struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
using SuperT = XGBBaseDeviceAllocator<T>;
using pointer = thrust::device_ptr<T>; // NOLINT
template<typename U>
struct rebind // NOLINT
@@ -391,10 +405,15 @@ struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
return SuperT::deallocate(ptr, n);
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBDefaultDeviceAllocatorImpl()
: SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{0}) {}
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
};

/**
* \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs allocations if verbose. Does not initialise memory on construction.
* \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs
* allocations if verbose. Does not initialise memory on construction.
*/
template <class T>
struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
@@ -752,6 +771,15 @@ xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
return ToSpan(vec, offset, size);
}

template <typename VectorT, typename T = typename VectorT::value_type,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(
std::unique_ptr<VectorT>& vec,
IndexT offset = 0,
IndexT size = std::numeric_limits<size_t>::max()) {
return ToSpan(*vec.get(), offset, size);
}

// thrust begin, similiar to std::begin
template <typename T>
thrust::device_ptr<T> tbegin(xgboost::HostDeviceVector<T>& vector) { // NOLINT
24 changes: 15 additions & 9 deletions src/predictor/gpu_predictor.cu
@@ -213,9 +213,10 @@ __global__ void PredictKernel(Data data,

class DeviceModel {
public:
dh::device_vector<RegTree::Node> nodes;
dh::device_vector<size_t> tree_segments;
dh::device_vector<int> tree_group;
// Need to lazily construct the vectors because GPU id is only known at runtime
std::unique_ptr<dh::device_vector<RegTree::Node>> nodes;
[Review thread on this line]
Member: I thought we are constructing the predictor lazily?
Author (hcho3): The device vectors were being constructed before cudaSetDevice() was called. The device vectors need access to the correct CUDA context at the time of construction, so I've delayed their construction.
Member: This is a bit worrying if it is necessary; I'm not sure we can guarantee this behaviour across XGBoost. Also, wouldn't it be easier to place DeviceModel inside a unique pointer?
Author (hcho3, Aug 11, 2020): Previously this was not necessary, since Thrust's device vector builds its memory resource (MR) lazily for each GPU it is used on. With the RMM allocator, however, the correct CUDA context needs to be set (with cudaSetDevice()) prior to the construction of the device vector. For now this line works, but in the longer term I can design a new device vector class that lazily constructs the device MR. As for placing DeviceModel inside a unique pointer: that won't work, because DeviceModel has a separate Init() function, and the correct CUDA context isn't set until Init() is called.

[Review thread on the cudaSetDevice(gpu_id) call in Init() below]
Member: Is there an earlier point in the program where we can set the device? E.g. in the learner, as soon as it receives the gpu_id parameter.
Member: Or you can use HostDeviceVector, which is lazy.
Author (hcho3): Done. I replaced it with HostDeviceVector.

std::unique_ptr<dh::device_vector<size_t>> tree_segments;
std::unique_ptr<dh::device_vector<int>> tree_group;
size_t tree_beg_; // NOLINT
size_t tree_end_; // NOLINT
int num_group;
@@ -224,16 +225,16 @@ class DeviceModel {
const thrust::host_vector<size_t>& h_tree_segments,
const thrust::host_vector<RegTree::Node>& h_nodes,
size_t tree_begin, size_t tree_end) {
nodes.resize(h_nodes.size());
dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
nodes->resize(h_nodes.size());
dh::safe_cuda(cudaMemcpyAsync(nodes->data().get(), h_nodes.data(),
sizeof(RegTree::Node) * h_nodes.size(),
cudaMemcpyHostToDevice));
tree_segments.resize(h_tree_segments.size());
dh::safe_cuda(cudaMemcpyAsync(tree_segments.data().get(), h_tree_segments.data(),
tree_segments->resize(h_tree_segments.size());
dh::safe_cuda(cudaMemcpyAsync(tree_segments->data().get(), h_tree_segments.data(),
sizeof(size_t) * h_tree_segments.size(),
cudaMemcpyHostToDevice));
tree_group.resize(model.tree_info.size());
dh::safe_cuda(cudaMemcpyAsync(tree_group.data().get(), model.tree_info.data(),
tree_group->resize(model.tree_info.size());
dh::safe_cuda(cudaMemcpyAsync(tree_group->data().get(), model.tree_info.data(),
sizeof(int) * model.tree_info.size(),
cudaMemcpyHostToDevice));
this->tree_beg_ = tree_begin;
@@ -243,6 +244,11 @@

void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, int32_t gpu_id) {
dh::safe_cuda(cudaSetDevice(gpu_id));
// Allocate device vectors using correct GPU ID context
nodes.reset(new dh::device_vector<RegTree::Node>());
tree_segments.reset(new dh::device_vector<size_t>());
tree_group.reset(new dh::device_vector<int>());

CHECK_EQ(model.param.size_leaf_vector, 0);
// Copy decision trees to device
thrust::host_vector<size_t> h_tree_segments{};
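The review thread above (on DeviceModel's device vectors) centres on one pattern: when the
allocator consults a per-device memory resource, as the RMM-backed allocator does, device
containers must not be constructed until cudaSetDevice() has selected the right GPU. Below is a
simplified, self-contained sketch of that pattern; plain thrust::device_vector stands in for
XGBoost's internal dh::device_vector, and the merged code ultimately switched to HostDeviceVector
instead:
```
// Simplified sketch of the lazy-construction pattern discussed in the review
// thread: defer building device containers until the target GPU is selected,
// so any per-device memory resource (e.g. RMM's) is picked up for that GPU.
#include <memory>
#include <vector>
#include <cuda_runtime.h>
#include <thrust/device_vector.h>

struct DeviceModelSketch {
  // Held behind unique_ptr so no device allocation happens when the struct is created.
  std::unique_ptr<thrust::device_vector<int>> tree_group;

  void Init(int gpu_id, const std::vector<int>& h_tree_info) {
    cudaSetDevice(gpu_id);  // bind the CUDA context before any device allocation
    tree_group.reset(
        new thrust::device_vector<int>(h_tree_info.begin(), h_tree_info.end()));
  }
};

int main() {
  DeviceModelSketch model;
  model.Init(0, std::vector<int>{0, 0, 1});  // device memory is allocated only now, on GPU 0
  return 0;
}
```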
4 changes: 2 additions & 2 deletions tests/ci_build/Dockerfile.gpu
@@ -17,8 +17,8 @@ ENV PATH=/opt/python/bin:$PATH

# Create new Conda environment with cuDF, Dask, and cuPy
RUN \
conda create -n gpu_test -c rapidsai -c nvidia -c conda-forge -c defaults \
python=3.7 cudf=0.14 cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
python=3.7 cudf=0.15* rmm=0.15* cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis

ENV GOSU_VERSION 1.10