Commit 424d2e95 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

update to prototype qasm3 handler to support qubits, qregs, qalloc (all in...


update to prototype qasm3 handler to support qubits, qregs, qalloc (all in qcor namespace), as well as vector<double> kernel args

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 4e0568b0
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -29,6 +29,11 @@ def ExtractQubitOp : QuantumOp<"qextract", []> {
    let results = (outs QubitType:$qbit);
}

def GeneralArrayExtractOp : QuantumOp<"array_extract", []> {
    let arguments = (ins ArrayType:$array, AnyInteger:$idx);
    let results = (outs AnyType:$element);
}

// Assign a qubit pointer (specified by the Qubit array and index) to an alias pointer. 
// Signature: void qassign(Array* destination_array, int destination_idx, Array* source_array, int source_idx)
def AssignQubitOp : QuantumOp<"qassign", []> {
+2 −4
Original line number Diff line number Diff line
@@ -50,7 +50,7 @@ for i in [0:n_counting] {
}

// Run inverse QFT 
iqft counting;
iqf2 counting;

// Now lets measure the counting qubits
bit c[n_counting];
@@ -58,6 +58,4 @@ measure counting -> c;

// Backend is QPP which is lsb, 
// so return should be 100
for i in [0:n_counting]{
    print(c[i]);
}
print(c);
+39 −22
Original line number Diff line number Diff line
#include "qir-qrt.hpp"
using qubit = Qubit*;
#include <stdio.h>

[[clang::syntax(qasm3)]] int test(int i, qubit q, qubit r) {
  int ten = 10 * i;
#include <cstring>
#define __qasm3__ [[clang::syntax(qasm3)]]

__qasm3__ int test(int i, qubit q, qubit r, qcor::qreg& qq,
                   std::vector<double> xx) {

  // qreg pass by reference until we figure out ref counting on array!!!

  int mult = 10 * i;
  // now do some quantum
  print("test out vector :) ", xx[0]);
  h q;
  h qq[0];  // show using a qreg... this will result in all 00s
  cx q, r;
  bit c[2];
  c[0] = measure q;
@@ -14,38 +23,46 @@ using qubit = Qubit*;
}

// Run with...
// clang++ -std=c++17 -fplugin=/path/to/libqasm3-syntax-handler.so -I /path/to/qcor/include/qcor -I/path/to/xacc/include/xacc -c qasm3_test.cpp
// llc -filetype=obj test.bc (test comes from kernel name, so will need to do this for all kernels)
// clang++ -L /path/to/install/lib -lqir-qrt -lqcor -lxacc -lqrt -lCppMicroServices test.o qasm3_test.o
// clang++ -std=c++17 -fplugin=/path/to/libqasm3-syntax-handler.so -I
// /path/to/qcor/include/qcor -I/path/to/xacc/include/xacc -c qasm3_test.cpp
// llc -filetype=obj test.bc (test comes from kernel name, so will need to do
// this for all kernels) clang++ -L /path/to/install/lib -lqir-qrt -lqcor -lxacc
// -lqrt -lCppMicroServices test.o qasm3_test.o
// ./a.out

int main(int argc, char** argv) {
  int x = 10;

  // Figure out how to initialize automatically
  __quantum__rt__initialize(argc, reinterpret_cast<int8_t**>(argv));
  // Can provide runtime parameters
  // via explicit initialize call.
  qcor::initialize(argc, argv);
  // Otherwise it will be called automatically

  int x = 10, shots = 50;
  std::vector<double> xx{1.2};

  // Connect this to qalloc(...)
  auto qreg = __quantum__rt__qubit_allocate_array(2);
  auto qreg = qcor::qalloc(2);

  // Should be able to get qubit from operator[] on qreg
  auto qbit_mem = __quantum__rt__array_get_element_ptr_1d(qreg, 0);
  auto qbit = reinterpret_cast<Qubit**>(qbit_mem)[0];
  auto qbit_mem2 = __quantum__rt__array_get_element_ptr_1d(qreg, 1);
  auto qbit2 = reinterpret_cast<Qubit**>(qbit_mem2)[0];
  // Can now extract the qubits individually
  auto q = qreg[0];
  auto r = qreg[1];

  // Run bell test...
  int ones = 0, zeros = 0;
  for (int i = 0; i < 50; i++) {
    auto y = test(x, qbit, qbit2);
  for (int i = 0; i < shots; i++) {
    // should be binary-as-int 00 = 0, or 11 = 3
    if (y == 3) {
    xx[0] = (double)i;
    if (test(32, q, r, qreg, xx)) {
      ones++;
    } else {
      zeros++;
    }
  }
  printf("Result: 11:%d, 00:%d\n", ones, zeros);

  assert(zeros == shots);

  printf("Result: {'11':%d, '00':%d}\n", ones, zeros);

  // quantum memory will be freed
  // when it goes out of scope.
  return 0;
}
 No newline at end of file
+65 −15
Original line number Diff line number Diff line
@@ -4,8 +4,8 @@
#include <regex>
#include <sstream>

#include "qasm3_handler_utils.hpp"
#include "openqasmv3_mlir_generator.hpp"
#include "qasm3_handler_utils.hpp"
#include "quantum_to_llvm.hpp"

using namespace clang;
@@ -18,6 +18,12 @@ void Qasm3SyntaxHandler::GetReplacement(Preprocessor &PP, Declarator &D,
  // Get the function name
  auto kernel_name = D.getName().Identifier->getName().str();

  auto &diagnostics = PP.getDiagnostics();
  auto invalid_kernel_arg = diagnostics.getCustomDiagID(
      clang::DiagnosticsEngine::Fatal,
      "Invalid quantum kernel argument - we do not know how to map this type "
      "to a mlir::Type yet (%0 %1).");

  // Create the MLIRContext and load the dialects
  mlir::MLIRContext context;
  context
@@ -38,10 +44,9 @@ void Qasm3SyntaxHandler::GetReplacement(Preprocessor &PP, Declarator &D,
  // build up associated mlir::Type arguments,
  // For vectors use Array *
  std::vector<mlir::Type> arg_types;
  std::vector<std::string> program_parameters, arg_type_strs;
  std::vector<std::string> program_parameters, arg_type_strs, var_attributes;
  const DeclaratorChunk::FunctionTypeInfo &FTI = D.getFunctionTypeInfo();
  for (unsigned int ii = 0; ii < FTI.NumParams; ii++) {

    // Get parameters as a ParmVarDecl
    auto &paramInfo = FTI.Params[ii];
    auto &decl = paramInfo.Param;
@@ -50,19 +55,33 @@ void Qasm3SyntaxHandler::GetReplacement(Preprocessor &PP, Declarator &D,
    auto type = parm_var_decl->getType().getTypePtr();

    // Get VarName and Type as strings
    Token IdentToken, TypeToken;
    Token IdentToken, TypeToken, test;
    PP.getRawToken(paramInfo.IdentLoc, IdentToken);
    PP.getRawToken(decl->getBeginLoc(), TypeToken);

    auto var = PP.getSpelling(IdentToken);
    auto type_str = PP.getSpelling(TypeToken);

    // Convert type to a mlir type
    mlir::Type t = convertClangType(type, type_str, context);
    if (!t) {
      auto db = diagnostics.Report(invalid_kernel_arg);
      db.AddString(type->getCanonicalTypeInternal().getAsString());
      db.AddString(var);
    }

    arg_types.push_back(t);

    // Add them to the vectors
    program_parameters.push_back(var);
    arg_type_strs.push_back(type_str);

    // Convert type to a mlir type
    mlir::Type t = convertClangType(type, type_str, context);
    arg_types.push_back(t);
    if (type->getCanonicalTypeInternal().getAsString().find("vector<double>") !=
        std::string::npos) {
      var_attributes.push_back("double");
    } else {
      var_attributes.push_back("");
    }
  }

  // std::cout << "SRC:\n" << ss.str() << "\n";
@@ -70,11 +89,12 @@ void Qasm3SyntaxHandler::GetReplacement(Preprocessor &PP, Declarator &D,
  // Get the return type as an mlir type,
  // as well as a string
  std::string ret_type_str = "";
  mlir::Type return_type = convertReturnType(D.getDeclSpec(), ret_type_str, context);
  mlir::Type return_type =
      convertReturnType(D.getDeclSpec(), ret_type_str, context);

  // Init the MLIRGen
  mlir_generator.initialize_mlirgen(kernel_name, arg_types, program_parameters,
                                    return_type);
                                    var_attributes, return_type);

  // Run the MLIRGen
  mlir_generator.mlirgen(src);
@@ -138,11 +158,37 @@ void Qasm3SyntaxHandler::GetReplacement(Preprocessor &PP, Declarator &D,
  sss << "extern \"C\" { " << ret_type_str << " __internal_mlir_" << kernel_name
      << "(" << arg_type_strs[0];
  for (int i = 1; i < arg_type_strs.size(); i++) {
    sss << ", " << arg_type_strs[i];
    std::string type_name = arg_type_strs[i];
    if (arg_type_strs[i].find("qcor::qreg") != std::string::npos) {
      type_name = "Array*";
    } else if (arg_type_strs[i].find("std::vector") != std::string::npos) {
      type_name = "Array*";
    }
    sss << ", " << type_name;
  }
  sss << ");}\n";

  // Rewrite the function to call the internal function
  sss << getDeclText(PP, D).str() << "{\n";

  // Perform any argument translation
  // e.g. map qcor::qreg to Array* with q.raw_array();
  for (int i = 0; i < arg_type_strs.size(); i++) {
    if (arg_type_strs[i].find("qcor::qreg") != std::string::npos) {
      auto old_var_name = program_parameters[i];
      sss << "auto __tmp_internal_qreg_array_" << old_var_name << " = "
          << old_var_name << ".raw_array();\n";
      std::replace(program_parameters.begin(), program_parameters.end(),
                   old_var_name, "__tmp_internal_qreg_array_" + old_var_name);
    } else if (arg_type_strs[i].find("std::vector") != std::string::npos) {
      auto old_var_name = program_parameters[i];
      sss << "auto __tmp_internal_vector_array_" << old_var_name
          << " = qcor::qir::toArray(" << old_var_name << ");\n";
      std::replace(program_parameters.begin(), program_parameters.end(),
                   old_var_name, "__tmp_internal_vector_array_" + old_var_name);
    }
  }
  sss << "if (!initialized) initialize();\n";
  sss << "return __internal_mlir_" << kernel_name << "("
      << program_parameters[0];
  for (int i = 1; i < program_parameters.size(); i++) {
@@ -156,7 +202,11 @@ void Qasm3SyntaxHandler::GetReplacement(Preprocessor &PP, Declarator &D,
  return;
}

void Qasm3SyntaxHandler::AddToPredefines(llvm::raw_string_ostream &OS) {}
void Qasm3SyntaxHandler::AddToPredefines(llvm::raw_string_ostream &OS) {
  OS << "#include \"qir-qrt.hpp\"\n";
  OS << "#include \"qir-types-utils.hpp\"\n";
  OS << "using qcor::qubit;\n";
}
}  // namespace qcor

static SyntaxHandlerRegistry::Add<qcor::Qasm3SyntaxHandler> X(
+60 −7
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
#include "Quantum/QuantumDialect.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/Type.h"
#include "clang/AST/TypeVisitor.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "llvm/Bitcode/BitcodeWriter.h"
@@ -18,9 +19,16 @@
#include "mlir/IR/Verifier.h"

namespace qcor {
mlir::Type convertClangType(const clang::Type* type, std::string& type_as_str,
                            mlir::MLIRContext& context) {
  if (auto BT = dyn_cast_or_null<clang::BuiltinType>(type)) {

class ClangToMLIRTypeVisitor
    : public clang::TypeVisitor<ClangToMLIRTypeVisitor, mlir::Type> {
 public:
  std::string& type_as_str;
  mlir::MLIRContext& context;

  ClangToMLIRTypeVisitor(mlir::MLIRContext& ctx, std::string& type_as_str_arg)
      : type_as_str(type_as_str_arg), context(ctx) {}
  mlir::Type VisitBuiltinType(const clang::BuiltinType* BT) {
    if (BT->isIntegerType()) {
      switch (BT->getKind()) {
        case BuiltinType::Short: {
@@ -66,18 +74,63 @@ mlir::Type convertClangType(const clang::Type* type, std::string& type_as_str,
      };
      // return CIL::IntegerTy::get(kind, qual, &mlirContext);
    } else if (BT->isFloatingType()) {
      // FIXME DO THIS
    }
  } else if (type->isStructuralType()) {
    if (type_as_str == "qubit") {
    return mlir::Type();
  }

  // This will handle qubit = Qubit*
  mlir::Type VisitTypedefType(const clang::TypedefType* r) {
    if (r->isPointerType()) {
      auto qual_type = r->getPointeeType();
      if (qual_type.getAsString().find("Qubit") != std::string::npos) {
        return mlir::OpaqueType::get(&context,
                                     mlir::Identifier::get("quantum", &context),
                                     llvm::StringRef("Qubit"));
      }
    }
    r->dump();
    return mlir::Type();
  }

  // This can handle qcor::qreg&
  mlir::Type VisitLValueReferenceType(const clang::LValueReferenceType* r) {

    auto qual_type = r->getPointeeType();
    if (qual_type.getAsString().find("qreg") != std::string::npos) {
      type_as_str = "qcor::qreg";
      return mlir::OpaqueType::get(&context,
                                   mlir::Identifier::get("quantum", &context),
                                   llvm::StringRef("Array"));
    }
    r->dump();
    return mlir::Type();
  }

  mlir::Type VisitElaboratedType(const clang::ElaboratedType* t) {
    if (t->getNamedType().getAsString().find("vector") != std::string::npos) {
      if (t->getNamedType().getAsString().find("double") != std::string::npos) {
        type_as_str = "std::vector<double>";
      } else if (t->getNamedType().getAsString().find("int") !=
                 std::string::npos) {
        type_as_str = "std::vector<int>";
      }
      return mlir::OpaqueType::get(&context,
                                   mlir::Identifier::get("quantum", &context),
                                   llvm::StringRef("Array"));
    }

    return mlir::Type();
  }
 
};

mlir::Type convertClangType(const clang::Type* type, std::string& type_as_str,
                            mlir::MLIRContext& context) {
  ClangToMLIRTypeVisitor visitor(context, type_as_str);
  return visitor.Visit(type);
}

mlir::Type convertReturnType(const clang::DeclSpec& spec,
                             std::string& ret_type_str,
                             mlir::MLIRContext& context) {
Loading