Commit a022f61f authored by Mccaskey, Alex's avatar Mccaskey, Alex

cleaning up isGradientBased and get algorithm impls on the nlopt and mlpack optimizers

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent af2c163a
Pipeline #109285 passed with stage
in 73 minutes and 43 seconds
......@@ -65,32 +65,9 @@ public:
throw std::bad_function_call();
}
const std::string get_algorithm() const {
virtual const std::string get_algorithm() const {return "";}
if(name() == "nlopt" && !options.stringExists(name() + "-optimizer")){
return "cobyla";
} else if (name() == "mlpack" && !options.stringExists(name() + "-optimizer")){
return "adam";
} else {
return options.getString(name() + "-optimizer");
}
}
const bool isGradientBased() const {
if(name() == "nlopt" && !options.stringExists(name() + "-optimizer")){
return false;
}
if(options.getString(name() + "-optimizer") == "cobyla" ||
options.getString(name() + "-optimizer") == "nelder-mead") {
return false;
}
return true;
}
virtual const bool isGradientBased() const {return false;}
};
} // namespace xacc
......
......@@ -18,6 +18,28 @@
using namespace ens;
namespace xacc {
const bool MLPACKOptimizer::isGradientBased() const {
std::string mlpack_opt_name = "adam";
if (options.stringExists("mlpack-optimizer")) {
mlpack_opt_name = options.getString("mlpack-optimizer");
}
std::vector<std::string> non_grad{"cmaes", "spsa"};
if (xacc::container::contains(non_grad, mlpack_opt_name)) {
return false;
} else {
return true;
}
}
const std::string MLPACKOptimizer::get_algorithm() const {
std::string mlpack_opt_name = "adam";
if (options.stringExists("mlpack-optimizer")) {
mlpack_opt_name = options.getString("mlpack-optimizer");
}
return mlpack_opt_name;
}
OptResult MLPACKOptimizer::optimize(OptFunction &function) {
......
......@@ -68,6 +68,8 @@ class MLPACKOptimizer : public Optimizer {
public:
MLPACKOptimizer() = default;
OptResult optimize(OptFunction &function) override;
const bool isGradientBased() const override;
const std::string get_algorithm() const override;
const std::string name() const override { return "mlpack"; }
const std::string description() const override { return ""; }
......
......@@ -25,6 +25,28 @@ double c_wrapper(const std::vector<double> &x, std::vector<double> &grad,
return e->f(x, grad);
}
const std::string NLOptimizer::get_algorithm() const {
std::string nlopt_opt_name = "cobyla";
if (options.stringExists("nlopt-optimizer")) {
nlopt_opt_name = options.getString("nlopt-optimizer");
}
return nlopt_opt_name;
}
const bool NLOptimizer::isGradientBased() const {
std::string nlopt_opt_name = "cobyla";
if (options.stringExists("nlopt-optimizer")) {
nlopt_opt_name = options.getString("nlopt-optimizer");
}
if (nlopt_opt_name == "l-bfgs") {
return true;
} else {
return false;
}
}
OptResult NLOptimizer::optimize(OptFunction &function) {
auto dim = function.dimensions();
......@@ -53,7 +75,8 @@ OptResult NLOptimizer::optimize(OptFunction &function) {
if (options.keyExists<int>("nlopt-maxeval")) {
maxeval = options.get<int>("nlopt-maxeval");
xacc::info("[NLOpt] max function evaluations set to " + std::to_string(maxeval));
xacc::info("[NLOpt] max function evaluations set to " +
std::to_string(maxeval));
}
std::vector<double> x(dim);
......@@ -76,7 +99,9 @@ OptResult NLOptimizer::optimize(OptFunction &function) {
_opt.set_ftol_rel(tol);
if (dim != x.size()) {
xacc::error("Invalid optimization configuration: function dim == " + std::to_string(dim) + ", param_size == " + std::to_string(x.size()));
xacc::error("Invalid optimization configuration: function dim == " +
std::to_string(dim) +
", param_size == " + std::to_string(x.size()));
}
double optF;
nlopt::result r;
......
......@@ -27,6 +27,8 @@ struct ExtraNLOptData {
class NLOptimizer : public Optimizer {
public:
OptResult optimize(OptFunction &function) override;
const bool isGradientBased() const override;
virtual const std::string get_algorithm() const;
const std::string name() const override { return "nlopt"; }
const std::string description() const override { return ""; }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment