Commit 5632939f authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Work on alias: skeleton for MLIR lowering of assign and create array



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 25019220
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
OPENQASM 3;
include "qelib1.inc";
qubit q[6];
// myreg[0] refers to the qubit q[1], myreg[1] -> q[3], etc.
let myreg = q[1, 3, 5];

for i in [0:3] {
  x q[i];
}
 No newline at end of file
+2 −0
Original line number Diff line number Diff line
@@ -6,8 +6,10 @@ TEST(qasm3VisitorTester, checkAlias) {
  const std::string src = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q[6];
x q[3];
// myreg[0] refers to the qubit q[1]
let myreg = q[1, 3, 5];
x myreg[0];
)#";
  auto mlir =
      qcor::mlir_compile("qasm3", src, "test", qcor::OutputType::MLIR, true);
+10 −6
Original line number Diff line number Diff line
@@ -38,8 +38,11 @@ antlrcpp::Any qasm3_visitor::visitAliasStatement(
    auto str_attr = builder.getStringAttr(alias);
    auto integer_attr =
        mlir::IntegerAttr::get(builder.getI64Type(), n_expressions);
    mlir::Value alias_allocation = builder.create<mlir::quantum::QaliasArrayAllocOp>(
    mlir::Value alias_allocation =
        builder.create<mlir::quantum::QaliasArrayAllocOp>(
            location, array_type, integer_attr, str_attr);
    // Add the alias register to the symbol table
    symbol_table.add_symbol(alias, alias_allocation);

    auto counter = 0;
    for (auto expr : expressions) {
@@ -49,6 +52,7 @@ antlrcpp::Any qasm3_visitor::visitAliasStatement(
          symbol_table.evaluate_constant_integer_expression(expr->getText());

      // get the src_extracted element from the original register
      auto qubit_type = get_custom_opaque_type("Qubit", builder.getContext());
      auto src_extracted = builder.create<mlir::quantum::ExtractQubitOp>(
          location, qubit_type, allocated_symbol,
          get_or_create_constant_integer_value(
+131 −0
Original line number Diff line number Diff line
@@ -115,6 +115,98 @@ class QallocOpLowering : public ConversionPattern {
  }
};

// The goal of QubitArrayAllocOpLowering is to lower all occurrences of the
// MLIR QuantumDialect createQubitArray to the MSFT QIR
// __quantum__rt__array_create_1d() quantum runtime function for Qubit*
// (create a generic array holding references to Qubit for aliasing purposes)
// as an LLVM MLIR Function and CallOp.
class QubitArrayAllocOpLowering : public ConversionPattern {
protected:
  // Constant string for runtime function name
  inline static const std::string qir_qubit_array_allocate =
      "__quantum__rt__array_create_1d";
  // Rudimentary symbol table, seen variables
  std::map<std::string, mlir::Value> &variables;
  /// Lower to:
  /// %Array* @__quantum__rt__array_create_1d(i32 %elementSizeInBytes, i64% nQubits) 
  /// where elementSizeInBytes = 8 (pointer size).
public:
  // Constructor, store seen variables
  explicit QubitArrayAllocOpLowering(MLIRContext *context,
                                     std::map<std::string, mlir::Value> &vars)
      : ConversionPattern(mlir::quantum::QaliasArrayAllocOp::getOperationName(), 1,
                          context),
        variables(vars) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Local Declarations, get location, parentModule
    // and the context
    auto loc = op->getLoc();
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();

    // First step is to get a reference to the Symbol Reference for the
    // __quantum__rt__array_create_1d QIR runtime function,
    // this will only declare it once and reuse each time it is seen
    FlatSymbolRefAttr symbol_ref;
    if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_qubit_array_allocate)) {
      symbol_ref = SymbolRefAttr::get(qir_qubit_array_allocate, context);
    } else {
      // prototype is (elementSize: int32, arraySize : int64) -> Array* :
      // qubit_array_ptr
      auto qubit_type = IntegerType::get(context, 64);
      auto element_size_type = IntegerType::get(context, 32);
      auto array_qbit_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
      auto array_alloc_ftype = LLVM::LLVMFunctionType::get(
          array_qbit_type, llvm::ArrayRef<Type>{element_size_type, qubit_type},
          false);

      // Insert the function declaration
      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(parentModule.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(
          parentModule->getLoc(), qir_qubit_array_allocate, array_alloc_ftype);
      symbol_ref = mlir::SymbolRefAttr::get(qir_qubit_array_allocate, context);
    }

    // Get as a QaliasArrayAllocOp, get its allocation size and qreg variable
    // name
    auto qallocOp = cast<mlir::quantum::QaliasArrayAllocOp>(op);
    auto size = qallocOp.size();
    auto qreg_name = qallocOp.name().str();

    Value create_size_int = rewriter.create<LLVM::ConstantOp>(
        loc, IntegerType::get(rewriter.getContext(), 64),
        rewriter.getIntegerAttr(rewriter.getI64Type(), size));

    Value element_size_int = rewriter.create<LLVM::ConstantOp>(
        loc, IntegerType::get(rewriter.getContext(), 32),
        rewriter.getIntegerAttr(
            rewriter.getI64Type(),
            /* element size = pointer size */ sizeof(void *)));

    auto array_qbit_type =
        LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
    auto qalloc_qir_call = rewriter.create<mlir::CallOp>(
        loc, symbol_ref, array_qbit_type,
        ArrayRef<Value>({element_size_int, create_size_int}));

    // Get the returned qubit array pointer Value
    auto qbit_array = qalloc_qir_call.getResult(0);

    // Remove the old QuantumDialect QallocOp
    rewriter.replaceOp(op, qbit_array);
    rewriter.eraseOp(op);
    // Save the qubit array variable to the symbol table
    variables.insert({qreg_name, qbit_array});

    return success();
  }
};

// declare void @__quantum__rt__qubit_release_array(%Array*)
class DeallocOpLowering : public ConversionPattern {
 protected:
@@ -592,6 +684,43 @@ class ExtractQubitOpConversion : public ConversionPattern {
  }
};

class AssignQubitOpConversion : public ConversionPattern {
protected:
  std::map<std::string, mlir::Value> &variables;

public:
  // CTor: store seen variables
  explicit AssignQubitOpConversion(MLIRContext *context,
                                   std::map<std::string, mlir::Value> &vars)
      : ConversionPattern(
            mlir::quantum::AssignQubitOp::getOperationName(), 1,
            context),
        variables(vars) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Local Declarations
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();
    auto location = parentModule->getLoc();
    // Source and Destinations are Qubit* type
    auto dest = operands[0];
    auto src = operands[1];
    // Cast source pointer to Qubit**
    auto bitcast = rewriter.create<LLVM::BitcastOp>(
        location,
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))),
        src);
    // Store source (Qubit**) to destination
    // auto store_qubit_ptr =
    //     rewriter.create<LLVM::StoreOp>(location, bitcast.res(), dest);

    return success();
  }
};

class CreateStringLiteralOpLowering : public ConversionPattern {
 private:
  std::map<std::string, mlir::Value> &variables;
@@ -902,6 +1031,8 @@ void QuantumToLLVMLoweringPass::runOnOperation() {
  patterns.insert<DeallocOpLowering>(&getContext(), variables);
  patterns.insert<QRTInitOpLowering>(&getContext(), variables);
  patterns.insert<QRTFinalizeOpLowering>(&getContext(), variables);
  patterns.insert<QubitArrayAllocOpLowering>(&getContext(), variables);
  patterns.insert<AssignQubitOpConversion>(&getContext(), variables);

  // We want to completely lower to LLVM, so we use a `FullConversion`. This
  // ensures that only legal operations will remain after the conversion.