diff --git a/examples/surrogates/rankingsvm.hpp b/examples/surrogates/rankingsvm.hpp index fd600ed2..a4da85f8 100644 --- a/examples/surrogates/rankingsvm.hpp +++ b/examples/surrogates/rankingsvm.hpp @@ -218,7 +218,13 @@ class RankingSVM { for (int i=0;i 1) + x = covinv * x; + else + { + for (int i=0;i cp = c; to_mat_vec(cp,x,fvalues); -//int niter = 1e6;//floor(50000*sqrt(c.at(0).get_x_size())); dVec xmean = eostrat::get_solutions().xmean(); _rsvm = RankingSVM(); _rsvm._encode = true; _rsvm.train(x,_rsvm_iter,cov,xmean); - this->set_train_error(this->compute_error(cp, - eostrat::get_solutions().csqinv())); + this->set_train_error(this->compute_error(cp,cov)); //debug //std::cout << "training error=" << _rsvm.error(x,x,fvalues,cov,xmean) << std::endl; @@ -95,7 +93,7 @@ template::get_solutions().xmean(); - _rsvm.predict(fit,x_test,x_train,eostrat::get_solutions().csqinv(),xmean); + _rsvm.predict(fit,x_test,x_train,cov,xmean); if (fit.size() != 0) for (int i=0;i<(int)c.size();i++) c.at(i).set_fvalue(fit(i)); diff --git a/src/surrogatestrategy.cc b/src/surrogatestrategy.cc index ebdd91ff..471375d1 100644 --- a/src/surrogatestrategy.cc +++ b/src/surrogatestrategy.cc @@ -89,7 +89,7 @@ namespace libcmaes std::sort(ctest_set.begin(),ctest_set.end(), [](Candidate const &c1, Candidate const &c2){return c1.get_fvalue() > c2.get_fvalue();}); - + this->predict(ctest_set,cov); double err = 0.0; @@ -134,8 +134,11 @@ namespace libcmaes // compute test error if needed. if (this->_niter != 0 && (int)this->_tset.size() >= this->_l) { - this->set_test_error(this->compute_error(eostrat::get_solutions().candidates(), - eostrat::get_solutions().csqinv())); //TODO: sep + if (!eostrat::_parameters.is_sep()) + this->set_test_error(this->compute_error(eostrat::_solutions._candidates, + eostrat::_solutions._csqinv)); + else this->set_test_error(this->compute_error(eostrat::_solutions._candidates, + eostrat::_solutions._sepcsqinv)); } // use original objective function and collect points. @@ -149,8 +152,11 @@ namespace libcmaes for (int r=0;r::get_solutions().get_candidate(r).set_x(candidates.col(r)); - this->predict(eostrat::get_solutions().candidates(), - eostrat::get_solutions().csqinv()); + if (!eostrat::_parameters.is_sep()) + this->predict(eostrat::get_solutions().candidates(), + eostrat::get_solutions().csqinv()); + else this->predict(eostrat::get_solutions().candidates(), + eostrat::get_solutions().sepcsqinv()); } } } @@ -227,8 +233,11 @@ namespace libcmaes // compute test error if needed. if (this->_niter != 0 && (int)this->_tset.size() >= this->_l) { - this->set_test_error(this->compute_error(eostrat::get_solutions().candidates(), - eostrat::_solutions._csqinv)); //TODO: sep + if (!eostrat::_parameters.is_sep()) + this->set_test_error(this->compute_error(eostrat::get_solutions().candidates(), + eostrat::_solutions._csqinv)); + else this->set_test_error(this->compute_error(eostrat::get_solutions().candidates(), + eostrat::_solutions._sepcsqinv)); } // use original objective function and collect points. @@ -261,8 +270,9 @@ namespace libcmaes // train surrogate as required. if (do_train()) { - this->train(this->_tset, - eostrat::_solutions._csqinv); + if (!eostrat::_parameters.is_sep()) + this->train(this->_tset,eostrat::_solutions._csqinv); + else this->train(this->_tset,eostrat::_solutions._sepcsqinv); } } @@ -273,9 +283,11 @@ namespace libcmaes eostrat::_solutions._candidates.clear(); for (int r=0;r::_solutions._candidates.push_back(Candidate(0.0,candidates.col(r)));//.at(r).set_x(candidates.col(r)); + eostrat::_solutions._candidates.push_back(Candidate(0.0,candidates.col(r))); } - this->predict(eostrat::_solutions._candidates,eostrat::_solutions._csqinv); // XXX: prediction from genotype since learning from genotype! + if (!eostrat::_parameters.is_sep()) + this->predict(eostrat::_solutions._candidates,eostrat::_solutions._csqinv); + else this->predict(eostrat::_solutions._candidates,eostrat::_solutions._sepcsqinv); eostrat::_solutions.sort_candidates(); // - draw 'a'set_test_error(this->compute_error(test_set,eostrat::_solutions._csqinv)); + if (!eostrat::_parameters.is_sep()) + this->set_test_error(this->compute_error(test_set,eostrat::_solutions._csqinv)); + else this->set_test_error(this->compute_error(test_set,eostrat::_solutions._sepcsqinv)); // set candidate set. eostrat::_solutions._candidates = ncandidates;