Commit ad3396c6 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Use the gradient func in ObjFunc workflow



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 443a21b1
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -7,11 +7,12 @@ namespace qcor {
namespace __internal__ {
std::shared_ptr<GradientFunction>
get_gradient_method(const std::string &type,
                    std::shared_ptr<ObjectiveFunction> obj_func) {
                    std::shared_ptr<ObjectiveFunction> obj_func,
                    xacc::HeterogeneousMap options) {
  if (!xacc::isInitialized())
    xacc::internal_compiler::compiler_InitializeXACC();
  auto service = xacc::getService<KernelGradientService>(type);
  service->initialize(obj_func);
  service->initialize(obj_func, std::move(options));
  return service;
}
} // namespace __internal__
+3 −1
Original line number Diff line number Diff line
@@ -30,9 +30,11 @@ public:
};

namespace __internal__ {
const std::string DEFAULT_GRADIENT_METHOD = "central";
std::shared_ptr<GradientFunction>
get_gradient_method(const std::string &type,
                    std::shared_ptr<ObjectiveFunction> obj_func);
                    std::shared_ptr<ObjectiveFunction> obj_func,
                    xacc::HeterogeneousMap options = {});
} // namespace __internal__

// Interface for gradient calculation services.
+19 −1
Original line number Diff line number Diff line
@@ -321,13 +321,31 @@ public:
    kernel = kernel_evaluator(x);
    helper->update_kernel(kernel);

    // Save the input dx:
    const auto input_dx = dx;
    
    auto cost_val = (*helper)(qreg, dx);
    // If we needs gradients:
    if (!dx.empty()) {
    // the optimizer requires dx (not empty)
    // and the concrete ObjFunc sub-class doesn't calculate the gradients.
    if (!dx.empty() && input_dx == dx) {
      if (dx.size() != x.size()) {
        error("Dimension mismatched: gradients and parameters vectors have "
              "different size.");
      }

      if (!gradiend_method) {
        std::string gradient_method_name =
            qcor::__internal__::DEFAULT_GRADIENT_METHOD;
        // Backward compatible:
        // If the "gradient-strategy" was specified in the option.
        if (options.stringExists("gradient-strategy")) {
          gradient_method_name = options.getString("gradient-strategy");
        }
        gradiend_method = qcor::__internal__::get_gradient_method(
            gradient_method_name, xacc::as_shared_ptr(this), options);
      }

      dx = (*gradiend_method)(x, cost_val);
    }
    return cost_val;
+21 −21
Original line number Diff line number Diff line
@@ -92,27 +92,27 @@ class VQEObjective : public ObjectiveFunction {
    current_iteration++;
    // qreg.addChild(tmp_child);

    if (!dx.empty() && options.stringExists("gradient-strategy")) {
      // Compute the gradient
      auto gradient_strategy =
          xacc::getService<xacc::AlgorithmGradientStrategy>(
              options.getString("gradient-strategy"));

      if (gradient_strategy->isNumerical() &&
          observable->getIdentitySubTerm()) {
        gradient_strategy->setFunctionValue(
            val - std::real(observable->getIdentitySubTerm()->coefficient()));
      }

      gradient_strategy->initialize(options);
      auto grad_kernels = gradient_strategy->getGradientExecutions(
          kernel, current_iterate_parameters);

      auto tmp_grad = qalloc(qreg.size());
      qpu->execute(xacc::as_shared_ptr(tmp_grad.results()), grad_kernels);
      auto tmp_grad_children = tmp_grad.results()->getChildren();
      gradient_strategy->compute(dx, tmp_grad_children);
    }
    // if (!dx.empty() && options.stringExists("gradient-strategy")) {
    //   // Compute the gradient
    //   auto gradient_strategy =
    //       xacc::getService<xacc::AlgorithmGradientStrategy>(
    //           options.getString("gradient-strategy"));

    //   if (gradient_strategy->isNumerical() &&
    //       observable->getIdentitySubTerm()) {
    //     gradient_strategy->setFunctionValue(
    //         val - std::real(observable->getIdentitySubTerm()->coefficient()));
    //   }

    //   gradient_strategy->initialize(options);
    //   auto grad_kernels = gradient_strategy->getGradientExecutions(
    //       kernel, current_iterate_parameters);

    //   auto tmp_grad = qalloc(qreg.size());
    //   qpu->execute(xacc::as_shared_ptr(tmp_grad.results()), grad_kernels);
    //   auto tmp_grad_children = tmp_grad.results()->getChildren();
    //   gradient_strategy->compute(dx, tmp_grad_children);
    // }
    return val;
  }