Less restrictive monotone constraints #2305

Closed
Changes from all commits
45 commits
1c47fac
Added 2 parameters to enable penalization of monotone splits and enab…
Jun 25, 2019
d56ff9b
Added getters and new parameters like parents of nodes in the trees.
Jun 25, 2019
b65ace1
Fill the feature_is_monotone variable in the trees.
Jun 26, 2019
cb9917f
Added utility function to know the number of bins of a feature.
Jun 26, 2019
8fea762
Added some debugging checks.
Jun 26, 2019
d412ca6
Leaf splits now keep track of the depth.
Jun 26, 2019
37010ce
Added the tree as a parameter in the function UpdateBestSplitsFromHis…
Jun 26, 2019
3ff354d
Added a new struct to keep track of the constraints, in a more precis…
Jun 26, 2019
d8563da
Added variables to keep track of the constraints efficiently.
Jun 26, 2019
10c5cbb
Added a penalty function for monotone splits.
Jun 26, 2019
d69e4ff
Added core functions to go through the tree, and update split when co…
Jun 26, 2019
e4ec6a0
Modified the way constraints are handled in feature histograms.
Jun 26, 2019
3ea72af
Removed old constraints that are not used anymore.
Jun 26, 2019
692af0d
Added a function to refit leaves at the end of the training.
Jun 26, 2019
7aaec4a
Updated tests.
Jun 26, 2019
21b8e32
Small bug fix.
Aug 6, 2019
c8cea6f
Made testing time for new tests more reasonable.
Aug 7, 2019
1b0713d
Removed code specific to the slow method.
Sep 9, 2019
742c8e0
Move the monotone penalty in the monotone_cosntraints file.
Sep 9, 2019
e50b141
Change name of current Constraints to LeafConstraints.
Sep 11, 2019
c49640e
Created a class for current constraints.
Sep 12, 2019
5e20b4b
Created the structure SplittingConstraints to encapsulate monotone co…
Sep 17, 2019
61ad8d6
Refactoring to make CurrentConstraints an array of SplittingConstraints.
Sep 17, 2019
f020bfb
Using the standard kEpsilon.
Sep 17, 2019
1e30c19
Making ComputeBestSplitForFeature static.
Sep 26, 2019
d2de571
Move the functions used for the constraints in the constraints files.…
Sep 27, 2019
e1125a9
Moved another monotone-constraints-related function in the monotone c…
Sep 27, 2019
dbe74f2
Added LearnerState structure to have functions with less arguments.
Sep 27, 2019
f1717b6
Remove commented unused code.
Oct 7, 2019
b0202e1
Removed unused variable splits_per_leaf_.
Oct 7, 2019
54b5d47
Remove duplicated function.
Oct 22, 2019
033de6f
Remove useless class members.
Oct 22, 2019
7272a35
Added a getter for splits_per_leaf_.
Oct 23, 2019
dbc3f07
Removed could_be_splittable_ useful only in the Slow method.
Oct 23, 2019
9593c3c
Remove old comment.
Oct 23, 2019
6040cbd
Pass constraint class directly to GetSolitGains.
Oct 24, 2019
3f10afe
Switched SplittingConstraints from a reference to a pointer.
Oct 24, 2019
ed61bab
Grouped the best_constraints from feature_histogram in a class.
Oct 24, 2019
f6eb56a
Changed constraints to a nullptr when there are no monotone constraints.
Oct 24, 2019
eb2c412
Changed data_partition from unique_ptr to regular pointer.
Oct 31, 2019
8d6ef50
Splitted Splitting constraints in left SplittingConstraint and RightS…
Nov 1, 2019
ceae2bf
Constraints classes are now passed to CalculateSplittedLeafOutput ins…
Nov 1, 2019
b632791
Fix bug.
Nov 1, 2019
d291b99
Use nullptr's when there are no constraints.
Dec 18, 2019
c0f1cc0
Clarified an if statement.
Dec 19, 2019
12 changes: 12 additions & 0 deletions docs/Parameters.rst
@@ -312,6 +312,18 @@ Learning Control Parameters

- dropout rate: a fraction of previous trees to drop during the dropout

- ``monotone_penalty`` :raw-html:`<a id="monotone_penalty" title="Permalink to this parameter" href="#monotone_penalty">&#x1F517;&#xFE0E;</a>`, default = ``0.``, type = double, aliases: ``monotone_splits_penalty``, constraints: ``0.0 <= monotone_penalty (< max_depth, if max_depth > 0)``

- used only if ``monotone_constraints`` is set

- monotone penalty: a penalization of ``0.0`` corresponds to no penalization. A penalization parameter ``X`` forbids any monotone splits on the first ``X`` (rounded down) level(s) of the tree. The penalty applied to monotone splits at a given depth is a continuous, increasing function of the penalization parameter (an illustrative sketch follows this diff)

- ``monotone_precise_method`` :raw-html:`<a id="monotone_precise_method" title="Permalink to this parameter" href="#monotone_precise_method">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool, aliases: ``monotone_constraints_precise_mode``

- used only if ``monotone_constraints`` is set

- monotone precise method: if set to ``false``, the program will run as fast as without constraints, but the results may be over-constrained. If set to ``true``, the program will be slower, but the results will be better. Note that if there are categorical features in the dataset, they will be split using the fast method regardless of this parameter. Also, this parameter can only be set to ``true`` if missing-value handling is disabled

- ``max_drop`` :raw-html:`<a id="max_drop" title="Permalink to this parameter" href="#max_drop">&#x1F517;&#xFE0E;</a>`, default = ``50``, type = int

- used only in ``dart``
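
A minimal illustration of the behaviour described for ``monotone_penalty`` above. This is a hedged sketch, not necessarily the exact formula used by the PR, and the function name is hypothetical:

#include <cmath>

// Gain multiplier for a monotone split at `depth` under penalization X.
// Depths 0 .. floor(X)-1 get 0 (monotone splits forbidden there); beyond
// that point the multiplier increases continuously with depth toward 1,
// and decreases continuously as the penalization grows.
double MonotoneSplitGainMultiplier(int depth, double penalization) {
  if (penalization >= depth + 1.0) return 0.0;  // forbidden at this depth
  if (penalization <= 1.0) return 1.0 - penalization / std::pow(2.0, depth);
  return 1.0 - std::pow(2.0, penalization - 1.0 - depth);
}

For example, with ``monotone_penalty = 2.5`` the root (depth 0) and depth 1 cannot take monotone splits, a depth-2 monotone split keeps about 29% of its gain, and deeper splits are penalized less and less.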
12 changes: 12 additions & 0 deletions include/LightGBM/config.h
@@ -325,6 +325,18 @@ struct Config {
// desc = dropout rate: a fraction of previous trees to drop during the dropout
double drop_rate = 0.1;

// alias = monotone_splits_penalty
// check = >=0.0
// check = <max_depth; if max_depth > 0
// desc = used only if ``monotone_constraints`` is set
// desc = monotone penalty: a penalization of 0 corresponds to no penalization. A penalization parameter X forbids any monotone splits on the first X (rounded down) level(s) of the tree. The penalty applied to monotone splits at a given depth is a continuous, increasing function of the penalization parameter
double monotone_penalty = 0.;

// alias = monotone_constraints_precise_mode
// desc = used only if ``monotone_constraints`` is set
// desc = monotone precise mode: if set to false, the program will run as fast as without constraints, but the results may be over-constrained. If set to true, the program will be slower, but the results will be better. Note that if there are categorical features in the dataset, they will be split using the fast method regardless of this parameter. Also, this parameter can only be set to true if missing-value handling is disabled
bool monotone_precise_mode = false;

// desc = used only in ``dart``
// desc = max number of dropped trees during one boosting iteration
// desc = ``<=0`` means no limit
85 changes: 78 additions & 7 deletions include/LightGBM/tree.h
@@ -60,7 +60,7 @@ class Tree {
int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, double left_weight, double right_weight,
float gain, MissingType missing_type, bool default_left);
float gain, MissingType missing_type, bool default_left, bool feature_is_monotone);

/*!
* \brief Performing a split on tree leaves, with categorical feature
@@ -80,9 +80,14 @@
* \param gain Split gain
* \return The index of new leaf.
*/
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type);

int SplitCategorical(int leaf, int feature, int real_feature,
const uint32_t *threshold_bin, int num_threshold_bin,
const uint32_t *threshold, int num_threshold,
double left_value, double right_value, int left_cnt,
int right_cnt, double left_weight, double right_weight,
float gain, MissingType missing_type,
bool feature_is_monotone);
Collaborator:

is feature_is_monotone needed?

Contributor Author:

Right now it is not needed, as we do not support monotone constraints for categorical features. But it makes the signature more consistent with the standard split, and we may support categorical variables for monotonic constraints in the future. Should I remove it?


/*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
@@ -124,6 +129,24 @@
inline int PredictLeafIndex(const double* feature_values) const;
inline int PredictLeafIndexByMap(const std::unordered_map<int, double>& feature_values) const;

// Get node parent
inline int node_parent(int node_idx) const;
// Get leaf parent
inline int leaf_parent(int node_idx) const;

// Get children
inline int left_child(int node_idx) const;
inline int right_child(int node_idx) const;

// Get whether the leaf is in a monotone subtree
inline bool leaf_is_in_monotone_subtree(int leaf_idx) const;

inline double internal_value(int node_idx) const;

inline uint32_t threshold_in_bin(int node_idx) const;

// Get the feature corresponding to the split
inline int split_feature_inner(int node_idx) const;

inline void PredictContrib(const double* feature_values, int num_features, double* output);

@@ -302,8 +325,10 @@
}
}

inline void Split(int leaf, int feature, int real_feature, double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain);
inline void Split(int leaf, int feature, int real_feature, double left_value,
double right_value, int left_cnt, int right_cnt, double left_weight,
double right_weight, float gain, bool feature_is_monotone);

/*!
* \brief Find leaf index of which record belongs by features
* \param feature_values Feature value of this record
@@ -402,12 +427,22 @@
std::vector<int> leaf_depth_;
double shrinkage_;
int max_depth_;
// add parent node information
std::vector<int> node_parent_;
// Keeps track of the monotone splits above the leaf
std::vector<bool> leaf_is_in_monotone_subtree_;
Collaborator:

Do we really need to store these states in the tree model, or could they go in other structures used only for MC? Since some users don't use MC, this will change the model format and cause compatibility problems.

Collaborator:

Another suggestion: it seems there are lots of changes in FeatureHistogram, which will hurt the readability of the current code. Could we decouple MC from the other parts, e.g. as an independent class that could be called from FeatureHistogram?

> This will hurt the readability of current codes. Could we decouple MC with other parts?

We were very careful to implement this in a way that limits the impact on readability, and I think we did a good job in this regard, given the complexity of what is being implemented. In fact, the current implementation of monotonic constraints isn't particularly readable; if anything, we've made it more readable, although obviously more complex.

If you have any specific examples where you feel readability has been impacted, let us know and we can discuss them on an example-by-example basis.

> Could we decouple MC with other parts?

We could, but I'm not convinced this would make things substantially more readable, and it would be some effort on our part to make such changes. It could even hurt readability due to increased indirection.

Collaborator:

Yeah, actually I mean maintainability: all the implementation of MC in one place, not distributed across the project. And I think good maintainability also makes the code more readable.

Contributor Author:

I will not be able to work on this during this week and next week, sorry about that. But I will make the necessary changes to improve the code after that.

Contributor Author:

@guolinke about the new states in the Tree class which would cause compatibility issues: first, the states are absolutely necessary to the functioning of monotone constraints as we implemented them; there is no way we can do without them. And I think these are sensible states that belong in the Tree class (I could even see them reused for other purposes).

An alternative I can think of would be to create a Tree-like class holding these states only, used wherever the Tree class is currently used. But that would create some redundancy, and would probably not be good for readability.

Since these fields won't be used when monotone constraints are not used, is it not possible to solve the compatibility problem instead? I am not sure when these issues would arise, but maybe there is a way to specify somewhere that if the formats don't match, these fields should remain void or be filled with dummy values. What do you think?

Collaborator:

@CharlesAuguste I see. Could we have these states in the tree class, but not save and load them in the model file? If this is possible, it is okay as well.

Collaborator:

Just double-checked the code; it seems the new states are not in the model file, so I think it is okay.
};

inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain) {
double left_weight, double right_weight, float gain, bool feature_is_monotone) {
int new_node_idx = num_leaves_ - 1;

// Update if there is a monotone split above the leaf
if (feature_is_monotone || leaf_is_in_monotone_subtree_[leaf]) {
leaf_is_in_monotone_subtree_[leaf] = true;
leaf_is_in_monotone_subtree_[num_leaves_] = true;
}
// update parent info
int parent = leaf_parent_[leaf];
if (parent >= 0) {
@@ -421,6 +456,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
// add new node
split_feature_inner_[new_node_idx] = feature;
split_feature_[new_node_idx] = real_feature;
node_parent_[new_node_idx] = parent;

split_gain_[new_node_idx] = gain;
// add two new leaves
@@ -529,6 +565,41 @@ inline int Tree::GetLeafByMap(const std::unordered_map<int, double>& feature_val
return ~node;
}

inline int Tree::node_parent(int node_idx) const {
return node_parent_[node_idx];
}

inline int Tree::left_child(int node_idx) const {
return left_child_[node_idx];
}

inline int Tree::right_child(int node_idx) const {
return right_child_[node_idx];
}

inline int Tree::split_feature_inner(int node_idx) const {
return split_feature_inner_[node_idx];
}

inline int Tree::leaf_parent(int node_idx) const {
return leaf_parent_[node_idx];
}

inline uint32_t Tree::threshold_in_bin(int node_idx) const {
#ifdef DEBUG
CHECK(node_idx >= 0);
#endif
return threshold_in_bin_[node_idx];
}

inline bool Tree::leaf_is_in_monotone_subtree(int leaf_idx) const {
return leaf_is_in_monotone_subtree_[leaf_idx];
}

inline double Tree::internal_value(int node_idx) const {
return internal_value_[node_idx];
}


} // namespace LightGBM

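
A hedged usage sketch (not code from the PR) of the parent getters added above, walking from a leaf to the root: the kind of traversal the constraining method needs in order to visit every monotone split above a leaf.

#include <cstdio>
#include <LightGBM/tree.h>

// Assumes `tree` is a trained LightGBM::Tree exposing this PR's getters.
void PrintAncestorSplits(const LightGBM::Tree &tree, int leaf_idx) {
  int node = tree.leaf_parent(leaf_idx);  // -1 if the leaf is the root
  while (node >= 0) {
    std::printf("node %d: inner feature %d, internal value %f\n",
                node, tree.split_feature_inner(node),
                tree.internal_value(node));
    node = tree.node_parent(node);  // -1 once the root has been visited
  }
}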
6 changes: 6 additions & 0 deletions src/boosting/gbdt_model_text.cpp
@@ -540,6 +540,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
}
}
@@ -548,6 +551,9 @@
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
}
}
24 changes: 24 additions & 0 deletions src/io/config_auto.cpp
@@ -6,6 +6,7 @@
* This file is auto generated by LightGBM\helpers\parameter_generator.py from LightGBM\include\LightGBM\config.h file.
*/
#include<LightGBM/config.h>
#include <LightGBM/utils/log.h>
namespace LightGBM {
std::unordered_map<std::string, std::string> Config::alias_table({
{"config_file", "config"},
@@ -80,6 +81,8 @@ std::unordered_map<std::string, std::string> Config::alias_table({
{"lambda", "lambda_l2"},
{"min_split_gain", "min_gain_to_split"},
{"rate_drop", "drop_rate"},
{"monotone_splits_penalty", "monotone_penalty"},
{"monotone_constraints_precise_mode", "monotone_precise_mode"},
{"topk", "top_k"},
{"mc", "monotone_constraints"},
{"monotone_constraint", "monotone_constraints"},
@@ -199,6 +202,8 @@ std::unordered_set<std::string> Config::parameter_set({
"lambda_l2",
"min_gain_to_split",
"drop_rate",
"monotone_penalty",
"monotone_precise_mode",
"max_drop",
"skip_drop",
"xgboost_dart_mode",
@@ -399,8 +404,21 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

if (GetString(params, "monotone_constraints", &tmp_str)) {
monotone_constraints = Common::StringToArray<int8_t>(tmp_str, ',');
Log::Warning("The constraining method was just changed, which could significantly affect results of the algorithm");
}

GetDouble(params, "monotone_penalty", &monotone_penalty);
bool constraints_exist = false;
for (auto it = monotone_constraints.begin(); it != monotone_constraints.end();
it++) {
if (*it != 0) {
constraints_exist = true;
}
}
CHECK(monotone_penalty == 0 || constraints_exist);
CHECK(max_depth <= 0 || monotone_penalty < max_depth);
CHECK(monotone_penalty >= 0.0);

if (GetString(params, "feature_contri", &tmp_str)) {
feature_contri = Common::StringToArray<double>(tmp_str, ',');
}
@@ -476,6 +494,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetBool(params, "use_missing", &use_missing);

GetBool(params, "monotone_precise_mode", &monotone_precise_mode);
CHECK(!monotone_precise_mode || !use_missing);
CHECK(!monotone_precise_mode || constraints_exist);

Collaborator:

I think this file is auto-generated; you cannot edit it by hand. You could have some post-processing in config.cpp.

Contributor Author:

I am not sure I understand; I tested these checks locally and they worked. What do you mean by the file being auto-generated? I am not very familiar with how the library is built, sorry.

GetBool(params, "zero_as_missing", &zero_as_missing);

GetBool(params, "two_round", &two_round);
@@ -607,6 +629,8 @@ std::string Config::SaveMembersToString() const {
str_buf << "[lambda_l2: " << lambda_l2 << "]\n";
str_buf << "[min_gain_to_split: " << min_gain_to_split << "]\n";
str_buf << "[drop_rate: " << drop_rate << "]\n";
str_buf << "[monotone_penalty: " << monotone_penalty << "]\n";
str_buf << "[monotone_precise_mode: " << monotone_precise_mode << "]\n";
str_buf << "[max_drop: " << max_drop << "]\n";
str_buf << "[skip_drop: " << skip_drop << "]\n";
str_buf << "[xgboost_dart_mode: " << xgboost_dart_mode << "]\n";
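
Illustrative only, with hypothetical values: a parameter combination that satisfies the checks added above. monotone_precise_mode requires monotone constraints to be set and missing-value handling to be disabled; monotone_penalty must be non-negative and, when max_depth > 0, strictly below max_depth.

#include <string>
#include <unordered_map>

std::unordered_map<std::string, std::string> params = {
    {"monotone_constraints", "1,-1,0"},  // at least one non-zero entry
    {"monotone_precise_mode", "true"},   // requires use_missing = false
    {"use_missing", "false"},
    {"monotone_penalty", "1.5"},         // 0 <= 1.5 < max_depth
    {"max_depth", "6"},
};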
23 changes: 17 additions & 6 deletions src/io/tree.cpp
@@ -24,7 +24,9 @@ Tree::Tree(int max_leaves)
threshold_.resize(max_leaves_ - 1);
decision_type_.resize(max_leaves_ - 1, 0);
split_gain_.resize(max_leaves_ - 1);
node_parent_.resize(max_leaves_ - 1);
leaf_parent_.resize(max_leaves_);
leaf_is_in_monotone_subtree_.resize(max_leaves_);
leaf_value_.resize(max_leaves_);
leaf_weight_.resize(max_leaves_);
leaf_count_.resize(max_leaves_);
@@ -38,6 +40,7 @@
leaf_value_[0] = 0.0f;
leaf_weight_[0] = 0.0f;
leaf_parent_[0] = -1;
node_parent_[0] = -1;
shrinkage_ = 1.0f;
num_cat_ = 0;
cat_boundaries_.push_back(0);
@@ -50,8 +53,11 @@

int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type, bool default_left) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int left_cnt, int right_cnt, double left_weight,
double right_weight, float gain,
MissingType missing_type, bool default_left,
bool feature_was_monotone) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain, feature_was_monotone);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
@@ -69,10 +75,15 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
return num_leaves_ - 1;
}

int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int Tree::SplitCategorical(int leaf, int feature, int real_feature,
const uint32_t *threshold_bin, int num_threshold_bin,
const uint32_t *threshold, int num_threshold,
double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt,
double left_weight, double right_weight,
float gain, MissingType missing_type,
bool feature_was_monotone) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain, feature_was_monotone);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
4 changes: 4 additions & 0 deletions src/treelearner/cost_effective_gradient_boosting.hpp
@@ -83,6 +83,10 @@ class CostEfficientGradientBoosting {
}
}

SplitInfo const & GetSplitInfo(int i) const {
return splits_per_leaf_[i];
}

private:
double CalculateOndemandCosts(int feature_index, int real_fidx, int leaf_index) const {
if (tree_learner_->config_->cegb_penalty_feature_lazy.empty()) {
31 changes: 16 additions & 15 deletions src/treelearner/data_parallel_tree_learner.cpp
@@ -146,7 +146,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}

template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
TREELEARNER_T::ConstructHistograms(this->is_feature_used_, true);
// construct local histograms
#pragma omp parallel for schedule(static)
@@ -160,11 +160,12 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
// Reduce scatter for histogram
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(HistogramBinEntry), block_start_.data(),
block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramBinEntry::SumReducer);
this->FindBestSplitsFromHistograms(this->is_feature_used_, true);
this->FindBestSplitsFromHistograms(this->is_feature_used_, true, tree);
}

template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
const std::vector<int8_t> &, bool, const Tree *tree) {
std::vector<SplitInfo> smaller_bests_per_thread(this->num_threads_, SplitInfo());
std::vector<SplitInfo> larger_bests_per_thread(this->num_threads_, SplitInfo());
std::vector<int8_t> smaller_node_used_features(this->num_features_, 1);
@@ -190,13 +191,14 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->smaller_leaf_histogram_array_[feature_index].RawData());
SplitInfo smaller_split;
// find best threshold for smaller child
// FIXME Fill the vectors with the actual constraints and thresholds
SplittingConstraints *constraints;
std::vector<uint32_t> thresholds;
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
this->smaller_leaf_splits_->min_constraint(),
this->smaller_leaf_splits_->max_constraint(),
&smaller_split);
this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
&smaller_split, constraints);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
@@ -210,13 +212,12 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->smaller_leaf_histogram_array_[feature_index]);
SplitInfo larger_split;
// find best threshold for larger child
// FIXME Fill the vectors with the actual constraints and thresholds
this->larger_leaf_histogram_array_[feature_index].FindBestThreshold(
this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()),
this->larger_leaf_splits_->min_constraint(),
this->larger_leaf_splits_->max_constraint(),
&larger_split);
this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()),
&larger_split, constraints);
larger_split.feature = real_feature_index;
if (larger_split > larger_bests_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_bests_per_thread[tid] = larger_split;
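
The FIXME above reflects why the scalar min_constraint()/max_constraint() pair no longer suffices: in the precise mode, the bounds on a leaf's output can differ across threshold ranges of the splitting feature. A hypothetical sketch of the shape such a constraints bundle could take (names and layout assumed, not taken from the PR):

#include <cstdint>
#include <vector>

struct SplittingConstraintsSketch {
  std::vector<double> min_constraints;  // lower bound per threshold segment
  std::vector<double> max_constraints;  // upper bound per threshold segment
  std::vector<uint32_t> thresholds;     // bin indices where the bounds change
};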