Skip to content

Commit

Permalink
Fix loading old model.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 28, 2020
1 parent 91c6463 commit 4fdee3e
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -694,15 +694,24 @@ class LearnerIO : public LearnerConfiguration {
warn_old_model = false;
}

if (mparam_.major_version >= 1) {
learner_model_param_ = LearnerModelParam(mparam_,
obj_->ProbToMargin(mparam_.base_score));
} else {
if (mparam_.major_version < 1) {
// Before 1.0.0, base_score is saved as a transformed value, and there's no version
// attribute in the saved model.
learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score);
std::string multi{"multi:"};
if (!std::equal(tparam_.objective.cbegin(), tparam_.objective.cend(),
multi.begin())) {
HostDeviceVector<float> t;
t.HostVector().resize(1);
t.HostVector().at(0) = mparam_.base_score;
this->obj_->PredTransform(&t);
auto base_score = t.HostVector().at(0);
mparam_.base_score = base_score;
}
warn_old_model = true;
}

learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
if (attributes_.find("objective") != attributes_.cend()) {
auto obj_str = attributes_.at("objective");
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
Expand Down Expand Up @@ -752,6 +761,7 @@ class LearnerIO : public LearnerConfiguration {
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
std::vector<std::pair<std::string, std::string> > extra_attr;
mparam.contain_extra_attrs = 1;
std::cout << "mparam.base_score:" << mparam.base_score << std::endl;

{
std::vector<std::string> saved_params;
Expand Down

0 comments on commit 4fdee3e

Please sign in to comment.