Commit 6cac1dda authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Work on supporting Q# control-adjoint lowering



The Q# compiler will create callable with 4 different function pointers and we need to track which one should be applied at invoke.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 35fed760
Loading
Loading
Loading
Loading
+7 −10
Original line number Diff line number Diff line
@@ -161,18 +161,15 @@ void __quantum__rt__capture_update_reference_count(Callable *clb,
void __quantum__rt__capture_update_alias_count(Callable *clb, int32_t count);
void __quantum__rt__callable_memory_management(int32_t index, Callable *clb,
                                               int64_t parameter);
Callable *__quantum__rt__callable_make_adjoint(Callable *clb);
Callable *__quantum__rt__callable_make_controlled(Callable *clb);
void __quantum__rt__callable_make_adjoint(Callable *clb);
void __quantum__rt__callable_make_controlled(Callable *clb);
// Implementation table: 4x callables of a specific signature
typedef struct impl_table_t {
  void (*f[4])(TuplePtr, TuplePtr, TuplePtr);
} impl_table_t;
typedef struct mem_management_cb_t {
  void (*f[2])(TuplePtr, int64_t);
} mem_management_cb_t;
// Create callable (from Q#): 
// See spec: https://github.com/microsoft/qsharp-language/blob/main/Specifications/QIR/Callables.md
Callable* __quantum__rt__callable_create(impl_table_t* ft, mem_management_cb_t* callbacks, TuplePtr capture);
Callable *
__quantum__rt__callable_create(Callable::CallableEntryType *ft,
                               Callable::CaptureCallbackType *callbacks,
                               TuplePtr capture);
// Classical Runtime:
// https://github.com/microsoft/qsharp-language/blob/main/Specifications/QIR/Classical-Runtime.md#classical-runtime
void __quantum__rt__fail(QirString *str);
+47 −2
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@
#include <stdexcept>
#include <unordered_map>
#include <vector>

#include <cstring>
// Defines implementations of QIR Opaque types

namespace qcor {
@@ -235,11 +235,56 @@ class IFunctor;

// QIR Callable implementation.
struct Callable {
  // Typedef's and constants
  typedef void (*CallableEntryType)(TuplePtr, TuplePtr, TuplePtr);
  typedef void (*CaptureCallbackType)(TuplePtr, int32_t);
  static int constexpr AdjointIdx = 1;
  static int constexpr ControlledIdx = 1 << 1;
  static int constexpr TableSize = 4;
  static int constexpr CaptureCallbacksTableSize = 2;
  // =======================================================

  void invoke(TuplePtr args, TuplePtr result);
  // Constructor from C++ functor
  Callable(qcor::qsharp::IFunctor *in_functor) : m_functor(in_functor) {}
  Callable(CallableEntryType *ftEntries, CaptureCallbackType *captureCallbacks,
           TuplePtr capture) {
    memcpy(m_functionTable, ftEntries, sizeof(m_functionTable));
    if (m_functionTable[0] == nullptr) {
      throw "Base functor must be defined.";
    }
    if (captureCallbacks != nullptr) {
      memcpy(m_captureCallbacks, captureCallbacks,
             sizeof(this->m_captureCallbacks));
    }
    m_capture = capture;
  }

  // Add arbitrary nested layer of control/adjoint
  // A + A = I; A + C = C + A = CA; C + C = C; CA + A = C; CA + C = CA
  void applyFunctor(int functorIdx) {
    if (functorIdx == Callable::AdjointIdx) {
      m_functorIdx ^= Callable::AdjointIdx;
      if (m_functionTable[m_functorIdx] == nullptr) {
        throw "The Callable doesn't have Adjoint implementation.";
      }
    }
    if (functorIdx == Callable::ControlledIdx) {
      m_functorIdx |= Callable::ControlledIdx;
      if (m_functionTable[m_functorIdx]) {
        throw "The Callable doesn't have Controlled implementation.";
      }
      m_controlledDepth++;
    }
  }

private:
  qcor::qsharp::IFunctor *m_functor;
  qcor::qsharp::IFunctor *m_functor = nullptr;
  CallableEntryType m_functionTable[TableSize] = {nullptr, nullptr, nullptr, nullptr};
  CaptureCallbackType m_captureCallbacks[CaptureCallbacksTableSize] = {nullptr, nullptr};
  TuplePtr m_capture = nullptr;
  int m_functorIdx = 0;
  int m_controlledDepth = 0;
};

// QIR string type (regular string with ref. counting)
+22 −6
Original line number Diff line number Diff line
@@ -41,13 +41,29 @@ void __quantum__rt__callable_memory_management(int32_t index, Callable *clb,
    std::cout << "CALL: " << __PRETTY_FUNCTION__ << "\n";
}

Callable *__quantum__rt__callable_make_adjoint(Callable *clb) {
void __quantum__rt__callable_make_adjoint(Callable *clb) {
  if (verbose)
    std::cout << "CALL: " << __PRETTY_FUNCTION__ << "\n";
  if (clb == nullptr) {
    return;
  }
  clb->applyFunctor(Callable::AdjointIdx);
}
Callable *__quantum__rt__callable_make_controlled(Callable *clb) {
void __quantum__rt__callable_make_controlled(Callable *clb) {
  if (verbose)
    std::cout << "CALL: " << __PRETTY_FUNCTION__ << "\n";
  if (clb == nullptr) {
    return;
  }
Callable* __quantum__rt__callable_create(impl_table_t* ft, mem_management_cb_t* callbacks, TuplePtr capture) {
  clb->applyFunctor(Callable::ControlledIdx);
}
Callable *
__quantum__rt__callable_create(Callable::CallableEntryType *ft,
                               Callable::CaptureCallbackType *callbacks,
                               TuplePtr capture) {
  if (verbose)
    std::cout << "CALL: " << __PRETTY_FUNCTION__ << "\n";
  auto clb = new Callable(ft, callbacks, capture);
  return clb;
}
}
 No newline at end of file
+10 −0
Original line number Diff line number Diff line
@@ -51,6 +51,16 @@ TuplePtr __quantum__rt__tuple_copy(int8_t *tuple, bool forceNewInstance) {
void Callable::invoke(TuplePtr args, TuplePtr result) {
  if (m_functor) {
    m_functor->execute(args, result);
    return;
  }
  if (m_functionTable[m_functorIdx]) {
    if (m_controlledDepth < 2) {
      m_functionTable[m_functorIdx](m_capture, args, result);
    }
    else {
      // TODO: flatten the control array.
      throw "Multi-controlled is not supported yet.";
    }
  }
}
}
 No newline at end of file