From a8b6986904c6af313dbfafb0e6709d0b9442d90d Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 15 Apr 2019 00:32:28 +0800 Subject: [PATCH 1/7] Port multi classes metrics to CUDA. --- src/metric/elementwise_metric.cu | 35 +-- src/metric/metric.cc | 2 +- .../{metric_param.h => metric_common.h} | 23 ++ src/metric/multiclass_metric.cc | 128 +-------- src/metric/multiclass_metric.cu | 251 ++++++++++++++++++ tests/cpp/metric/test_multiclass_metric.cc | 14 +- tests/cpp/metric/test_multiclass_metric.cu | 5 + 7 files changed, 304 insertions(+), 154 deletions(-) rename src/metric/{metric_param.h => metric_common.h} (50%) create mode 100644 src/metric/multiclass_metric.cu create mode 100644 tests/cpp/metric/test_multiclass_metric.cu diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index b694feafd519..6a4497601d16 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2018 by Contributors + * Copyright 2015-2019 by Contributors * \file elementwise_metric.cc * \brief evaluation metrics for elementwise binary or regression. * \author Kailong Chen, Tianqi Chen @@ -9,7 +9,7 @@ #include #include -#include "metric_param.h" +#include "metric_common.h" #include "../common/math.h" #include "../common/common.h" @@ -28,29 +28,9 @@ namespace metric { DMLC_REGISTRY_FILE_TAG(elementwise_metric); template -class MetricsReduction { +class ElementWiseMetricsReduction { public: - class PackedReduceResult { - double residue_sum_; - double weights_sum_; - friend MetricsReduction; - - public: - XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {} - XGBOOST_DEVICE PackedReduceResult(double residue, double weight) : - residue_sum_{residue}, weights_sum_{weight} {} - - XGBOOST_DEVICE - PackedReduceResult operator+(PackedReduceResult const& other) const { - return PackedReduceResult { residue_sum_ + other.residue_sum_, - weights_sum_ + other.weights_sum_ }; - } - double Residue() const { return residue_sum_; } - double Weights() const { return weights_sum_; } - }; - - public: - explicit MetricsReduction(EvalRow policy) : + explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {} PackedReduceResult CpuReduceMetrics( @@ -144,9 +124,8 @@ class MetricsReduction { DeviceReduceMetrics(id, index, weights, labels, preds); } - for (size_t i = 0; i < devices.Size(); ++i) { - result.residue_sum_ += res_per_device[i].residue_sum_; - result.weights_sum_ += res_per_device[i].weights_sum_; + for (auto const& res : res_per_device) { + result += res; } } #endif // defined(XGBOOST_USE_CUDA) @@ -370,7 +349,7 @@ struct EvalEWiseBase : public Metric { MetricParam param_; - MetricsReduction reducer_; + ElementWiseMetricsReduction reducer_; }; XGBOOST_REGISTER_METRIC(RMSE, "rmse") diff --git a/src/metric/metric.cc b/src/metric/metric.cc index 076a0ce91c95..8d3d9d9280cc 100644 --- a/src/metric/metric.cc +++ b/src/metric/metric.cc @@ -6,7 +6,7 @@ #include #include -#include "metric_param.h" +#include "metric_common.h" namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::MetricReg); diff --git a/src/metric/metric_param.h b/src/metric/metric_common.h similarity index 50% rename from src/metric/metric_param.h rename to src/metric/metric_common.h index 3fd51a9f58be..6b727496999f 100644 --- a/src/metric/metric_param.h +++ b/src/metric/metric_common.h @@ -25,6 +25,29 @@ struct MetricParam : public dmlc::Parameter { }; }; +class PackedReduceResult { + double residue_sum_; + double weights_sum_; + +public: + XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {} + XGBOOST_DEVICE PackedReduceResult(double residue, double weight) + : residue_sum_{residue}, weights_sum_{weight} {} + + XGBOOST_DEVICE + PackedReduceResult operator+(PackedReduceResult const &other) const { + return PackedReduceResult{residue_sum_ + other.residue_sum_, + weights_sum_ + other.weights_sum_}; + } + PackedReduceResult &operator+=(PackedReduceResult const &other) { + this->residue_sum_ += other.residue_sum_; + this->weights_sum_ += other.weights_sum_; + return *this; + } + double Residue() const { return residue_sum_; } + double Weights() const { return weights_sum_; } +}; + } // namespace metric } // namespace xgboost diff --git a/src/metric/multiclass_metric.cc b/src/metric/multiclass_metric.cc index a1db79901fbc..7733a334f5c0 100644 --- a/src/metric/multiclass_metric.cc +++ b/src/metric/multiclass_metric.cc @@ -1,126 +1,8 @@ /*! - * Copyright 2015 by Contributors - * \file multiclass_metric.cc - * \brief evaluation metrics for multiclass classification. - * \author Kailong Chen, Tianqi Chen + * Copyright 2019 XGBoost contributors */ -#include -#include -#include -#include "../common/math.h" +// Dummy file to keep the CUDA conditional compile trick. -namespace xgboost { -namespace metric { -// tag the this file, used by force static link later. -DMLC_REGISTRY_FILE_TAG(multiclass_metric); - -/*! - * \brief base class of multi-class evaluation - * \tparam Derived the name of subclass - */ -template -struct EvalMClassBase : public Metric { - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK(preds.Size() % info.labels_.Size() == 0) - << "label and prediction size not match"; - const size_t nclass = preds.Size() / info.labels_.Size(); - CHECK_GE(nclass, 1U) - << "mlogloss and merror are only used for multi-class classification," - << " use logloss for binary classification"; - const auto ndata = static_cast(info.labels_.Size()); - double sum = 0.0, wsum = 0.0; - int label_error = 0; - - const auto& labels = info.labels_.HostVector(); - const auto& weights = info.weights_.HostVector(); - const std::vector& h_preds = preds.HostVector(); - -#pragma omp parallel for reduction(+: sum, wsum) schedule(static) - for (bst_omp_uint i = 0; i < ndata; ++i) { - const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f; - auto label = static_cast(labels[i]); - if (label >= 0 && label < static_cast(nclass)) { - sum += Derived::EvalRow(label, - h_preds.data() + i * nclass, - nclass) * wt; - wsum += wt; - } else { - label_error = label; - } - } - CHECK(label_error >= 0 && label_error < static_cast(nclass)) - << "MultiClassEvaluation: label must be in [0, num_class)," - << " num_class=" << nclass << " but found " << label_error << " in label"; - - double dat[2]; dat[0] = sum, dat[1] = wsum; - if (distributed) { - rabit::Allreduce(dat, 2); - } - return Derived::GetFinal(dat[0], dat[1]); - } - /*! - * \brief to be implemented by subclass, - * get evaluation result from one row - * \param label label of current instance - * \param pred prediction value of current instance - * \param nclass number of class in the prediction - */ - inline static bst_float EvalRow(int label, - const bst_float *pred, - size_t nclass); - /*! - * \brief to be overridden by subclass, final transformation - * \param esum the sum statistics returned by EvalRow - * \param wsum sum of weight - */ - inline static bst_float GetFinal(bst_float esum, bst_float wsum) { - return esum / wsum; - } - - private: - // used to store error message - const char *error_msg_; -}; - -/*! \brief match error */ -struct EvalMatchError : public EvalMClassBase { - const char* Name() const override { - return "merror"; - } - inline static bst_float EvalRow(int label, - const bst_float *pred, - size_t nclass) { - return common::FindMaxIndex(pred, pred + nclass) != pred + static_cast(label); - } -}; - -/*! \brief match error */ -struct EvalMultiLogLoss : public EvalMClassBase { - const char* Name() const override { - return "mlogloss"; - } - inline static bst_float EvalRow(int label, - const bst_float *pred, - size_t nclass) { - const bst_float eps = 1e-16f; - auto k = static_cast(label); - if (pred[k] > eps) { - return -std::log(pred[k]); - } else { - return -std::log(eps); - } - } -}; - -XGBOOST_REGISTER_METRIC(MatchError, "merror") -.describe("Multiclass classification error.") -.set_body([](const char* param) { return new EvalMatchError(); }); - -XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss") -.describe("Multiclass negative loglikelihood.") -.set_body([](const char* param) { return new EvalMultiLogLoss(); }); -} // namespace metric -} // namespace xgboost +#if !defined(XGBOOST_USE_CUDA) +#include "multiclass_metric.cu" +#endif // !defined(XGBOOST_USE_CUDA) diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu new file mode 100644 index 000000000000..95090fe465a9 --- /dev/null +++ b/src/metric/multiclass_metric.cu @@ -0,0 +1,251 @@ +/*! + * Copyright 2015 by Contributors + * \file multiclass_metric.cc + * \brief evaluation metrics for multiclass classification. + * \author Kailong Chen, Tianqi Chen + */ +#include +#include +#include + +#if defined(XGBOOST_USE_CUDA) +#include +#include // thrust::plus<> +#include "../common/device_helpers.cuh" +#endif // XGBOOST_USE_CUDA + +#include "metric_common.h" +#include "../common/math.h" + +namespace xgboost { +namespace metric { +// tag the this file, used by force static link later. +DMLC_REGISTRY_FILE_TAG(multiclass_metric); + +template +class MultiClassMetricsReduction { + public: + MultiClassMetricsReduction() {} + + PackedReduceResult CpuReduceMetrics( + const HostDeviceVector& weights, + const HostDeviceVector& labels, + const HostDeviceVector& preds, + const size_t n_class) const { + size_t ndata = labels.Size(); + + const auto& h_labels = labels.HostVector(); + const auto& h_weights = weights.HostVector(); + const auto& h_preds = preds.HostVector(); + + bst_float residue_sum = 0; + bst_float weights_sum = 0; + int label_error = 0; + bool has_weight = h_weights.size() > 0; + bool const is_null_weight = weights.Size() == 0; + +#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static) + for (omp_ulong idx = 0; idx < ndata; ++idx) { + bst_float weight = is_null_weight ? 1.0f : h_weights[idx]; + auto label = static_cast(h_labels[idx]); + if (label >= 0 && label < static_cast(n_class)) { + residue_sum += EvalRowPolicy::EvalRow( + label, h_preds.data() + idx * n_class, n_class) * weight; + weights_sum += weight; + } else { + label_error = label; + } + } + CHECK(label_error >= 0 && label_error < static_cast(n_class)) + << "MultiClassEvaluation: label must be in [0, num_class)," + << " num_class=" << n_class << " but found " << label_error << " in label"; + PackedReduceResult res { residue_sum, weights_sum }; + return res; + } + +#if defined(XGBOOST_USE_CUDA) + + PackedReduceResult DeviceReduceMetrics( + GPUSet::GpuIdType device_id, + size_t device_index, + const HostDeviceVector& weights, + const HostDeviceVector& labels, + const HostDeviceVector& preds, + const size_t n_class) { + size_t n_data = labels.DeviceSize(device_id); + + thrust::counting_iterator begin(0); + thrust::counting_iterator end = begin + n_data; + + auto s_labels = labels.DeviceSpan(device_id); + auto s_preds = preds.DeviceSpan(device_id); + auto s_weights = weights.DeviceSpan(device_id); + + bool const is_null_weight = weights.Size() == 0; + + PackedReduceResult result = thrust::transform_reduce( + thrust::cuda::par(allocators_.at(device_index)), + begin, end, + [=] XGBOOST_DEVICE(size_t idx) { + bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; + bst_float residue = 0; + auto label = static_cast(s_labels[idx]); + if (label >= 0 && label < static_cast(n_class)) { + residue = EvalRowPolicy::EvalRow( + label, &s_preds[idx * n_class], n_class) * weight; + } + return PackedReduceResult{ residue, weight }; + }, + PackedReduceResult(), + thrust::plus()); + + return result; + } + +#endif // XGBOOST_USE_CUDA + + PackedReduceResult Reduce( + GPUSet devices, + size_t n_class, + const HostDeviceVector& weights, + const HostDeviceVector& labels, + const HostDeviceVector& preds) { + PackedReduceResult result; + + if (devices.IsEmpty()) { + result = CpuReduceMetrics(weights, labels, preds, n_class); + } +#if defined(XGBOOST_USE_CUDA) + else { // NOLINT + if (allocators_.size() != devices.Size()) { + allocators_.clear(); + allocators_.resize(devices.Size()); + } + preds.Reshard(devices); + labels.Reshard(devices); + weights.Reshard(devices); + std::vector res_per_device(devices.Size()); + +#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1) + for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) { + dh::safe_cuda(cudaSetDevice(id)); + size_t index = devices.Index(id); + res_per_device.at(index) = + DeviceReduceMetrics(id, index, weights, labels, preds, n_class); + } + + for (auto const& res : res_per_device) { + result += res; + } + } +#endif // defined(XGBOOST_USE_CUDA) + return result; + } + + private: +#if defined(XGBOOST_USE_CUDA) + std::vector allocators_; +#endif // defined(XGBOOST_USE_CUDA) +}; + +/*! + * \brief base class of multi-class evaluation + * \tparam Derived the name of subclass + */ +template +struct EvalMClassBase : public Metric { + void Configure( + const std::vector >& args) override { + param_.InitAllowUnknown(args); + } + + bst_float Eval(const HostDeviceVector &preds, + const MetaInfo &info, + bool distributed) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK(preds.Size() % info.labels_.Size() == 0) + << "label and prediction size not match"; + const size_t nclass = preds.Size() / info.labels_.Size(); + CHECK_GE(nclass, 1U) + << "mlogloss and merror are only used for multi-class classification," + << " use logloss for binary classification"; + const auto ndata = static_cast(info.labels_.Size()); + + const auto& labels = info.labels_.HostVector(); + const auto& weights = info.weights_.HostVector(); + const std::vector& h_preds = preds.HostVector(); + + GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus, ndata); + auto result = reducer_.Reduce(devices, nclass, info.weights_, info.labels_, preds); + double dat[2] { result.Residue(), result.Weights() }; + + if (distributed) { + rabit::Allreduce(dat, 2); + } + return Derived::GetFinal(dat[0], dat[1]); + } + /*! + * \brief to be implemented by subclass, + * get evaluation result from one row + * \param label label of current instance + * \param pred prediction value of current instance + * \param nclass number of class in the prediction + */ + XGBOOST_DEVICE static bst_float EvalRow(int label, + const bst_float *pred, + size_t nclass); + /*! + * \brief to be overridden by subclass, final transformation + * \param esum the sum statistics returned by EvalRow + * \param wsum sum of weight + */ + inline static bst_float GetFinal(bst_float esum, bst_float wsum) { + return esum / wsum; + } + + private: + MultiClassMetricsReduction reducer_; + MetricParam param_; + // used to store error message + const char *error_msg_; +}; + +/*! \brief match error */ +struct EvalMatchError : public EvalMClassBase { + const char* Name() const override { + return "merror"; + } + XGBOOST_DEVICE static bst_float EvalRow(int label, + const bst_float *pred, + size_t nclass) { + return common::FindMaxIndex(pred, pred + nclass) != pred + static_cast(label); + } +}; + +/*! \brief match error */ +struct EvalMultiLogLoss : public EvalMClassBase { + const char* Name() const override { + return "mlogloss"; + } + XGBOOST_DEVICE static bst_float EvalRow(int label, + const bst_float *pred, + size_t nclass) { + const bst_float eps = 1e-16f; + auto k = static_cast(label); + if (pred[k] > eps) { + return -std::log(pred[k]); + } else { + return -std::log(eps); + } + } +}; + +XGBOOST_REGISTER_METRIC(MatchError, "merror") +.describe("Multiclass classification error.") +.set_body([](const char* param) { return new EvalMatchError(); }); + +XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss") +.describe("Multiclass negative loglikelihood.") +.set_body([](const char* param) { return new EvalMultiLogLoss(); }); +} // namespace metric +} // namespace xgboost diff --git a/tests/cpp/metric/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc index fe5a3788d1a2..dc2ce6f43858 100644 --- a/tests/cpp/metric/test_multiclass_metric.cc +++ b/tests/cpp/metric/test_multiclass_metric.cc @@ -3,8 +3,17 @@ #include "../helpers.h" -TEST(Metric, MultiClassError) { +using Arg = std::pair; + +#if defined(__CUDACC__) +#define N_GPU() Arg{"n_gpus", "1"} +#else +#define N_GPU() Arg{"n_gpus", "0"} +#endif + +TEST(Metric, DeclareUnifiedTest(MultiClassError)) { xgboost::Metric * metric = xgboost::Metric::Create("merror"); + metric->Configure({N_GPU()}); ASSERT_STREQ(metric->Name(), "merror"); EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0})); EXPECT_NEAR(GetMetricEval( @@ -17,8 +26,9 @@ TEST(Metric, MultiClassError) { delete metric; } -TEST(Metric, MultiClassLogLoss) { +TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { xgboost::Metric * metric = xgboost::Metric::Create("mlogloss"); + metric->Configure({N_GPU()}); ASSERT_STREQ(metric->Name(), "mlogloss"); EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0})); EXPECT_NEAR(GetMetricEval( diff --git a/tests/cpp/metric/test_multiclass_metric.cu b/tests/cpp/metric/test_multiclass_metric.cu new file mode 100644 index 000000000000..84117f5b3dee --- /dev/null +++ b/tests/cpp/metric/test_multiclass_metric.cu @@ -0,0 +1,5 @@ +/*! + * Copyright 2018 XGBoost contributors + */ +// Dummy file to keep the CUDA conditional compile trick. +#include "test_multiclass_metric.cc" \ No newline at end of file From c126415774d960fada6b916fb512c78fd9da1386 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 15 Apr 2019 02:24:02 +0800 Subject: [PATCH 2/7] Correct year. --- src/metric/metric_common.h | 2 +- src/metric/multiclass_metric.cu | 2 +- tests/cpp/metric/test_multiclass_metric.cu | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index 6b727496999f..5c2173027202 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -1,5 +1,5 @@ /*! - * Copyright 2018 by Contributors + * Copyright 2018-2019 by Contributors * \file metric_param.cc */ #ifndef XGBOOST_METRIC_METRIC_PARAM_H_ diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 95090fe465a9..6e00e0609c72 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2015 by Contributors + * Copyright 2015-2019 by Contributors * \file multiclass_metric.cc * \brief evaluation metrics for multiclass classification. * \author Kailong Chen, Tianqi Chen diff --git a/tests/cpp/metric/test_multiclass_metric.cu b/tests/cpp/metric/test_multiclass_metric.cu index 84117f5b3dee..8a087565b3da 100644 --- a/tests/cpp/metric/test_multiclass_metric.cu +++ b/tests/cpp/metric/test_multiclass_metric.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2018 XGBoost contributors + * Copyright 2019 XGBoost contributors */ // Dummy file to keep the CUDA conditional compile trick. #include "test_multiclass_metric.cc" \ No newline at end of file From 408642a7f2e40dedee80631c2559b7f2ebab9c39 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 15 Apr 2019 02:32:49 +0800 Subject: [PATCH 3/7] Remove syncing. --- src/metric/multiclass_metric.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 6e00e0609c72..0db0309b9485 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -171,10 +171,6 @@ struct EvalMClassBase : public Metric { << " use logloss for binary classification"; const auto ndata = static_cast(info.labels_.Size()); - const auto& labels = info.labels_.HostVector(); - const auto& weights = info.weights_.HostVector(); - const std::vector& h_preds = preds.HostVector(); - GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus, ndata); auto result = reducer_.Reduce(devices, nclass, info.weights_, info.labels_, preds); double dat[2] { result.Residue(), result.Weights() }; From dd38ccce99cd25af3f3ac22194a7b426a1029799 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 15 Apr 2019 12:38:12 +0800 Subject: [PATCH 4/7] Linter. --- src/metric/metric_common.h | 8 ++++---- src/metric/multiclass_metric.cu | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index 5c2173027202..293d0a235926 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -2,8 +2,8 @@ * Copyright 2018-2019 by Contributors * \file metric_param.cc */ -#ifndef XGBOOST_METRIC_METRIC_PARAM_H_ -#define XGBOOST_METRIC_METRIC_PARAM_H_ +#ifndef XGBOOST_METRIC_METRIC_COMMON_H_ +#define XGBOOST_METRIC_METRIC_COMMON_H_ #include #include "../common/common.h" @@ -29,7 +29,7 @@ class PackedReduceResult { double residue_sum_; double weights_sum_; -public: + public: XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {} XGBOOST_DEVICE PackedReduceResult(double residue, double weight) : residue_sum_{residue}, weights_sum_{weight} {} @@ -51,4 +51,4 @@ class PackedReduceResult { } // namespace metric } // namespace xgboost -#endif // XGBOOST_METRIC_METRIC_PARAM_H_ +#endif // XGBOOST_METRIC_METRIC_COMMON_H_ diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 0db0309b9485..cdcdf71e5a09 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -25,7 +25,7 @@ DMLC_REGISTRY_FILE_TAG(multiclass_metric); template class MultiClassMetricsReduction { public: - MultiClassMetricsReduction() {} + MultiClassMetricsReduction() = default; PackedReduceResult CpuReduceMetrics( const HostDeviceVector& weights, @@ -41,7 +41,6 @@ class MultiClassMetricsReduction { bst_float residue_sum = 0; bst_float weights_sum = 0; int label_error = 0; - bool has_weight = h_weights.size() > 0; bool const is_null_weight = weights.Size() == 0; #pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static) From 7186cc46ba1b7dc3108fbe38bdb790f790ad1896 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 15 Apr 2019 12:43:36 +0800 Subject: [PATCH 5/7] Includes. --- src/metric/elementwise_metric.cu | 6 +++--- src/metric/multiclass_metric.cu | 12 ++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 6a4497601d16..a9221be849bf 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -14,10 +14,10 @@ #include "../common/common.h" #if defined(XGBOOST_USE_CUDA) -#include +#include // thrust::cuda::par +#include // thrust::plus<> #include -#include -#include // thrust::plus<> +#include #include "../common/device_helpers.cuh" #endif // XGBOOST_USE_CUDA diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index cdcdf71e5a09..459edf8ea3e6 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -8,15 +8,19 @@ #include #include +#include "metric_common.h" +#include "../common/math.h" +#include "../common/common.h" + #if defined(XGBOOST_USE_CUDA) +#include // thrust::cuda::par +#include // thrust::plus<> #include -#include // thrust::plus<> +#include + #include "../common/device_helpers.cuh" #endif // XGBOOST_USE_CUDA -#include "metric_common.h" -#include "../common/math.h" - namespace xgboost { namespace metric { // tag the this file, used by force static link later. From 143c9a3c7ee425d57652b93cb60a42e782f2099c Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 15 Apr 2019 14:03:35 +0800 Subject: [PATCH 6/7] Add multi gpu tests. --- src/metric/multiclass_metric.cu | 3 ++- tests/cpp/metric/test_multiclass_metric.cc | 29 ++++++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 459edf8ea3e6..eb93de045f2b 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -63,6 +63,7 @@ class MultiClassMetricsReduction { << "MultiClassEvaluation: label must be in [0, num_class)," << " num_class=" << n_class << " but found " << label_error << " in label"; PackedReduceResult res { residue_sum, weights_sum }; + return res; } @@ -124,7 +125,7 @@ class MultiClassMetricsReduction { allocators_.clear(); allocators_.resize(devices.Size()); } - preds.Reshard(devices); + preds.Reshard(GPUDistribution::Granular(devices, n_class)); labels.Reshard(devices); weights.Reshard(devices); std::vector res_per_device(devices.Size()); diff --git a/tests/cpp/metric/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc index dc2ce6f43858..79954784593a 100644 --- a/tests/cpp/metric/test_multiclass_metric.cc +++ b/tests/cpp/metric/test_multiclass_metric.cc @@ -1,5 +1,6 @@ // Copyright by Contributors #include +#include #include "../helpers.h" @@ -11,13 +12,13 @@ using Arg = std::pair; #define N_GPU() Arg{"n_gpus", "0"} #endif -TEST(Metric, DeclareUnifiedTest(MultiClassError)) { +inline void TestMultiClassError(std::vector args) { xgboost::Metric * metric = xgboost::Metric::Create("merror"); - metric->Configure({N_GPU()}); + metric->Configure(args); ASSERT_STREQ(metric->Name(), "merror"); EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0})); EXPECT_NEAR(GetMetricEval( - metric, {1, 0, 0, 0, 1, 0, 0, 0, 1}, {0, 1, 2}), 0, 1e-10); + metric, {1, 0, 0, 0, 1, 0, 0, 0, 1}, {0, 1, 2}), 0, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}, {0, 1, 2}), @@ -26,9 +27,13 @@ TEST(Metric, DeclareUnifiedTest(MultiClassError)) { delete metric; } -TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { +TEST(Metric, DeclareUnifiedTest(MultiClassError)) { + TestMultiClassError({N_GPU()}); +} + +inline void TestMultiClassLogLoss(std::vector args) { xgboost::Metric * metric = xgboost::Metric::Create("mlogloss"); - metric->Configure({N_GPU()}); + metric->Configure(args); ASSERT_STREQ(metric->Name(), "mlogloss"); EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0})); EXPECT_NEAR(GetMetricEval( @@ -40,3 +45,17 @@ TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { delete metric; } + +TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { + TestMultiClassLogLoss({N_GPU()}); +} + +#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__) +TEST(Metric, MGPU_MultiClassError) { + TestMultiClassError({Arg{"n_gpus", "-1"}}); + TestMultiClassError({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}}); + + TestMultiClassLogLoss({Arg{"n_gpus", "-1"}}); + TestMultiClassLogLoss({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}}); +} +#endif // defined(XGBOOST_USE_NCCL) From 248247bed665c9bd64994ac31c6618ea6169701d Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 15 Apr 2019 16:17:01 +0800 Subject: [PATCH 7/7] Check label error. --- src/metric/multiclass_metric.cu | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index eb93de045f2b..88af0014ed5a 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -28,6 +28,12 @@ DMLC_REGISTRY_FILE_TAG(multiclass_metric); template class MultiClassMetricsReduction { + void CheckLabelError(int32_t label_error, size_t n_class) const { + CHECK(label_error >= 0 && label_error < static_cast(n_class)) + << "MultiClassEvaluation: label must be in [0, num_class)," + << " num_class=" << n_class << " but found " << label_error << " in label"; + } + public: MultiClassMetricsReduction() = default; @@ -59,9 +65,7 @@ class MultiClassMetricsReduction { label_error = label; } } - CHECK(label_error >= 0 && label_error < static_cast(n_class)) - << "MultiClassEvaluation: label must be in [0, num_class)," - << " num_class=" << n_class << " but found " << label_error << " in label"; + CheckLabelError(label_error, n_class); PackedReduceResult res { residue_sum, weights_sum }; return res; @@ -86,6 +90,8 @@ class MultiClassMetricsReduction { auto s_weights = weights.DeviceSpan(device_id); bool const is_null_weight = weights.Size() == 0; + auto s_label_error = label_error_.GetSpan(1); + s_label_error[0] = 0; PackedReduceResult result = thrust::transform_reduce( thrust::cuda::par(allocators_.at(device_index)), @@ -94,14 +100,17 @@ class MultiClassMetricsReduction { bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; bst_float residue = 0; auto label = static_cast(s_labels[idx]); - if (label >= 0 && label < static_cast(n_class)) { + if (label >= 0 && label < static_cast(n_class)) { residue = EvalRowPolicy::EvalRow( label, &s_preds[idx * n_class], n_class) * weight; + } else { + s_label_error[0] = label; } return PackedReduceResult{ residue, weight }; }, PackedReduceResult(), thrust::plus()); + CheckLabelError(s_label_error[0], n_class); return result; } @@ -148,6 +157,7 @@ class MultiClassMetricsReduction { private: #if defined(XGBOOST_USE_CUDA) + dh::PinnedMemory label_error_; std::vector allocators_; #endif // defined(XGBOOST_USE_CUDA) };