Commit 1de31581 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Working callback callable interface w/ Q# QIR



Passing callable (callback) from QCOR C++ to Q#; marshal the tuple params: simple types and array type.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 096163d1
Loading
Loading
Loading
Loading
+25 −25
Original line number Diff line number Diff line
@@ -7,15 +7,16 @@ open QCOR.Intrinsic;
// Returns the final energy.
operation DeuteronVqe(shots: Int, stepper : ((Double, Double[]) => Double[])) : Double {
    // Stopping conditions:
    let max_iters = 100;
    let max_iters = 10;
    let f_tol = 0.01;
    let initial_params = [0.0];
    let initial_params = [1.23];
        
    mutable opt_params = initial_params;

    mutable numParityOnes = 0;
    use (qubits = Qubit[2])
    mutable energy_val = 0.0;
    use qubits = Qubit[2]
    {
        for iter_id in 1..max_iters {
            mutable numParityOnes = 0;
            for test in 1..shots {
                X(qubits[0]);
                Ry(opt_params[0], qubits[1]);
@@ -27,17 +28,16 @@ operation DeuteronVqe(shots: Int, stepper : ((Double, Double[]) => Double[])) :
                {
                    set numParityOnes += 1;
                }
            if M(qubits[0]) == One {
                X(qubits[0]);
            }
            if M(qubits[1]) == One {
                X(qubits[1]);
                Reset(qubits[0]);
                Reset(qubits[1]);
            }
            set energy_val =  IntAsDouble(shots - numParityOnes)/IntAsDouble(shots) - IntAsDouble(numParityOnes)/IntAsDouble(shots);
            // Stepping...
            set opt_params = stepper(energy_val, opt_params);
        } 
    }
    let res =  IntAsDouble(shots - numParityOnes)/IntAsDouble(shots) - IntAsDouble(numParityOnes)/IntAsDouble(shots);

    set opt_params = stepper(res, opt_params);
    return res;
    // Final energy:
    return energy_val;
}
}
 No newline at end of file
+69 −8
Original line number Diff line number Diff line
@@ -5,6 +5,64 @@
// Include the external QSharp function.
qcor_include_qsharp(QCOR__DeuteronVqe__body, double, int64_t shots, Callable* opt_stepper);

// Implement of a callback for Q# via IFunctor interface.
// TODO: this is a rigid impl. for protityping,
// we will handle generic callback signature transformation.
class vqe_callback : public qsharp::IFunctor {
public:
  virtual void execute(TuplePtr args, TuplePtr result) override {
    auto next = unpack(args, m_costVal);
    next = unpack(next, m_previousParams);
    auto _result = internal_execute();
    auto test = pack(result, _result);
  }
  vqe_callback(
      std::function<std::vector<double>(double, std::vector<double>)> functor)
      : m_functor(functor) {}

private:
  std::vector<double> internal_execute() {
    std::vector<double> result = m_functor(m_costVal, m_previousParams);
    return result;
  }

  TuplePtr pack(TuplePtr io_tuple, const std::vector<double> &in_vec) {
    ::Array *qirArray = new ::Array(in_vec.size(), sizeof(double));
    for (size_t i = 0; i < in_vec.size(); ++i) {
      auto dest = qirArray->getItemPointer(i);
      auto src = &in_vec[i];
      memcpy(dest, src, sizeof(double));
    }

    TupleHeader *th = ::TupleHeader::getHeader(io_tuple);
    memcpy(io_tuple, &qirArray, sizeof(::Array *));
    return io_tuple + sizeof(::Array *);
  }

  TuplePtr unpack(TuplePtr in_tuple, double &out_val) {
    out_val = *(reinterpret_cast<double *>(in_tuple));
    return in_tuple + sizeof(double);
  }

  TuplePtr unpack(TuplePtr in_tuple, std::vector<double> &out_val) {
    out_val.clear();
    ::Array *arrayPtr = *(reinterpret_cast<::Array **>(in_tuple));
    // std::cout << "Array of size " << arrayPtr->size()
    //           << "; element size = " << arrayPtr->element_size() << "\n";
    for (size_t i = 0; i < arrayPtr->size(); ++i) {
      const double el =
          *(reinterpret_cast<double *>(arrayPtr->getItemPointer(i)));
      out_val.emplace_back(el);
    }
    return in_tuple + sizeof(::Array *);
  }

private:
  double m_costVal = 0.0;
  std::vector<double> m_previousParams;
  std::function<std::vector<double>(double, std::vector<double>)> m_functor;
};

// Compile with:
// Include both the qsharp source and this driver file
// in the command line.
@@ -13,14 +71,17 @@ qcor_include_qsharp(QCOR__DeuteronVqe__body, double, int64_t shots, Callable* op
// $ ./a.out
int main() {
  std::function<std::vector<double>(double, std::vector<double>)> stepper =
      [&](double in_costVal, std::vector<double> previous_params) -> std::vector<double> {
      [&](double in_costVal,
          std::vector<double> previous_params) -> std::vector<double> {
    std::cout << "HELLO CALLBACK!\n";
        return {1.0};
    std::cout << "Cost value = " << in_costVal << "\n";
    return {previous_params[0] + 0.5};
  };
  qcor::qsharp::CallBack<std::vector<double>, double, std::vector<double>> cbFunc(
      stepper);

  Callable cb(&cbFunc);
  vqe_callback test(stepper);
  // Create a QIR callable
  Callable cb(&test);

  const double exp_val_xx = QCOR__DeuteronVqe__body(1024, &cb);
  return 0;
}
 No newline at end of file
+5 −13
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@
#include <cassert>
#include <stdexcept>
#include <functional>
#include <cstring>
#include <iostream>
// Defines implementations of QIR Opaque types

// FIXME - Qubit should be a struct that keeps track of idx
@@ -120,10 +122,10 @@ class IFunctor;
}
} // namespace qcor

// QIR Callable implementation.
struct Callable {
  void invoke(TuplePtr args, TuplePtr result);
  Callable(qcor::qsharp::IFunctor *in_functor) : m_functor(in_functor) {}

private:
  qcor::qsharp::IFunctor *m_functor;
};
@@ -133,21 +135,11 @@ namespace qcor {
std::vector<int64_t> getRangeValues(::Array *in_array, const ::Range &in_range);

namespace qsharp {
// A generic base class of qcor function-like objects
// that will be invoked by Q# as a callable.
class IFunctor {
public:
  virtual void execute(TuplePtr args, TuplePtr result) = 0;
};
template <typename ReturnType, typename... ArgTypes>
class CallBack : public IFunctor {
public:
  CallBack(std::function<ReturnType(ArgTypes...)> in_func) : m_func(in_func){};
  virtual void execute(TuplePtr args, TuplePtr result) override {
    printf("Howdy callable\n");
    printf(__PRETTY_FUNCTION__);
  }

private:
  std::function<ReturnType(ArgTypes...)> m_func;
};
} // namespace qsharp
} // namespace qcor