Skip to content

Commit

Permalink
added surrogate support to 'sep' algorithms, ref #57
Browse files Browse the repository at this point in the history
  • Loading branch information
Emmanuel Benazera committed Oct 8, 2014
1 parent 5c83a39 commit 0282740
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
8 changes: 7 additions & 1 deletion examples/surrogates/rankingsvm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,13 @@ class RankingSVM
{
for (int i=0;i<x.cols();i++)
x.col(i) -= xmean;
x = covinv * x;
if (covinv.cols() > 1)
x = covinv * x;
else
{
for (int i=0;i<x.cols();i++)
x.col(i) = covinv.cwiseProduct(x.col(i));
}
}

/**
Expand Down
6 changes: 2 additions & 4 deletions examples/surrogates/sample-code-surrogate-rsvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@ template<class TCovarianceUpdate=CovarianceUpdate,class TGenoPheno=GenoPheno<NoB
dVec fvalues;
std::vector<Candidate> cp = c;
to_mat_vec(cp,x,fvalues);
//int niter = 1e6;//floor(50000*sqrt(c.at(0).get_x_size()));
dVec xmean = eostrat<TGenoPheno>::get_solutions().xmean();
_rsvm = RankingSVM<RBFKernel>();
_rsvm._encode = true;
_rsvm.train(x,_rsvm_iter,cov,xmean);

this->set_train_error(this->compute_error(cp,
eostrat<TGenoPheno>::get_solutions().csqinv()));
this->set_train_error(this->compute_error(cp,cov));

//debug
//std::cout << "training error=" << _rsvm.error(x,x,fvalues,cov,xmean) << std::endl;
Expand All @@ -95,7 +93,7 @@ template<class TCovarianceUpdate=CovarianceUpdate,class TGenoPheno=GenoPheno<NoB

dVec fit;
dVec xmean = eostrat<TGenoPheno>::get_solutions().xmean();
_rsvm.predict(fit,x_test,x_train,eostrat<TGenoPheno>::get_solutions().csqinv(),xmean);
_rsvm.predict(fit,x_test,x_train,cov,xmean);
if (fit.size() != 0)
for (int i=0;i<(int)c.size();i++)
c.at(i).set_fvalue(fit(i));
Expand Down
38 changes: 26 additions & 12 deletions src/surrogatestrategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ namespace libcmaes

std::sort(ctest_set.begin(),ctest_set.end(),
[](Candidate const &c1, Candidate const &c2){return c1.get_fvalue() > c2.get_fvalue();});

this->predict(ctest_set,cov);

double err = 0.0;
Expand Down Expand Up @@ -134,8 +134,11 @@ namespace libcmaes
// compute test error if needed.
if (this->_niter != 0 && (int)this->_tset.size() >= this->_l)
{
this->set_test_error(this->compute_error(eostrat<TGenoPheno>::get_solutions().candidates(),
eostrat<TGenoPheno>::get_solutions().csqinv())); //TODO: sep
if (!eostrat<TGenoPheno>::_parameters.is_sep())
this->set_test_error(this->compute_error(eostrat<TGenoPheno>::_solutions._candidates,
eostrat<TGenoPheno>::_solutions._csqinv));
else this->set_test_error(this->compute_error(eostrat<TGenoPheno>::_solutions._candidates,
eostrat<TGenoPheno>::_solutions._sepcsqinv));
}

// use original objective function and collect points.
Expand All @@ -149,8 +152,11 @@ namespace libcmaes
for (int r=0;r<candidates.cols();r++)
{
eostrat<TGenoPheno>::get_solutions().get_candidate(r).set_x(candidates.col(r));
this->predict(eostrat<TGenoPheno>::get_solutions().candidates(),
eostrat<TGenoPheno>::get_solutions().csqinv());
if (!eostrat<TGenoPheno>::_parameters.is_sep())
this->predict(eostrat<TGenoPheno>::get_solutions().candidates(),
eostrat<TGenoPheno>::get_solutions().csqinv());
else this->predict(eostrat<TGenoPheno>::get_solutions().candidates(),
eostrat<TGenoPheno>::get_solutions().sepcsqinv());
}
}
}
Expand Down Expand Up @@ -227,8 +233,11 @@ namespace libcmaes
// compute test error if needed.
if (this->_niter != 0 && (int)this->_tset.size() >= this->_l)
{
this->set_test_error(this->compute_error(eostrat<TGenoPheno>::get_solutions().candidates(),
eostrat<TGenoPheno>::_solutions._csqinv)); //TODO: sep
if (!eostrat<TGenoPheno>::_parameters.is_sep())
this->set_test_error(this->compute_error(eostrat<TGenoPheno>::get_solutions().candidates(),
eostrat<TGenoPheno>::_solutions._csqinv));
else this->set_test_error(this->compute_error(eostrat<TGenoPheno>::get_solutions().candidates(),
eostrat<TGenoPheno>::_solutions._sepcsqinv));
}

// use original objective function and collect points.
Expand Down Expand Up @@ -261,8 +270,9 @@ namespace libcmaes
// train surrogate as required.
if (do_train())
{
this->train(this->_tset,
eostrat<TGenoPheno>::_solutions._csqinv);
if (!eostrat<TGenoPheno>::_parameters.is_sep())
this->train(this->_tset,eostrat<TGenoPheno>::_solutions._csqinv);
else this->train(this->_tset,eostrat<TGenoPheno>::_solutions._sepcsqinv);
}
}

Expand All @@ -273,9 +283,11 @@ namespace libcmaes
eostrat<TGenoPheno>::_solutions._candidates.clear();
for (int r=0;r<candidates.cols();r++)
{
eostrat<TGenoPheno>::_solutions._candidates.push_back(Candidate(0.0,candidates.col(r)));//.at(r).set_x(candidates.col(r));
eostrat<TGenoPheno>::_solutions._candidates.push_back(Candidate(0.0,candidates.col(r)));
}
this->predict(eostrat<TGenoPheno>::_solutions._candidates,eostrat<TGenoPheno>::_solutions._csqinv); // XXX: prediction from genotype since learning from genotype!
if (!eostrat<TGenoPheno>::_parameters.is_sep())
this->predict(eostrat<TGenoPheno>::_solutions._candidates,eostrat<TGenoPheno>::_solutions._csqinv);
else this->predict(eostrat<TGenoPheno>::_solutions._candidates,eostrat<TGenoPheno>::_solutions._sepcsqinv);
eostrat<TGenoPheno>::_solutions.sort_candidates();

// - draw 'a'<lambda_pre samples according to lambda_pre*N(0,theta_sel0^2) and retain each sample from initial population, with rank r < floor(a)
Expand Down Expand Up @@ -320,7 +332,9 @@ namespace libcmaes
}

// test error.
this->set_test_error(this->compute_error(test_set,eostrat<TGenoPheno>::_solutions._csqinv));
if (!eostrat<TGenoPheno>::_parameters.is_sep())
this->set_test_error(this->compute_error(test_set,eostrat<TGenoPheno>::_solutions._csqinv));
else this->set_test_error(this->compute_error(test_set,eostrat<TGenoPheno>::_solutions._sepcsqinv));

// set candidate set.
eostrat<TGenoPheno>::_solutions._candidates = ncandidates;
Expand Down

0 comments on commit 0282740

Please sign in to comment.