Commit 8768407c authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

qassign to take two regs and two indices



Realized that we cannot extract from an unintialized qreg since qextract involves LoadOp.

Hence, qassign to take destination and source register arrays and indices explicitly.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 5632939f
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -29,11 +29,10 @@ def ExtractQubitOp : QuantumOp<"qextract", []> {
    let results = (outs QubitType:$qbit);
}

// Assign a qubit pointer (extracted w/ qextract) to an alias pointer. 
// Signature: void qassign(Qubit* destination, Qubit* source)
// where destination and source are retrieved by qextract.
// Assign a qubit pointer (specified by the Qubit array and index) to an alias pointer. 
// Signature: void qassign(Array* destination_array, int destination_idx, Array* source_array, int source_idx)
def AssignQubitOp : QuantumOp<"qassign", []> {
    let arguments = (ins QubitType:$dest, QubitType:$src);
    let arguments = (ins ArrayType:$dest_qreg, AnyInteger:$dest_idx, ArrayType:$src_qreg, AnyInteger:$src_idx);
    let results = (outs);
}

+3 −0
Original line number Diff line number Diff line
@@ -10,6 +10,9 @@ x q[3];
// myreg[0] refers to the qubit q[1]
let myreg = q[1, 3, 5];
x myreg[0];
h myreg[1];
let alias = q[0, 2, 4];
cx alias[1], myreg[2];
)#";
  auto mlir =
      qcor::mlir_compile("qasm3", src, "test", qcor::OutputType::MLIR, true);
+7 −16
Original line number Diff line number Diff line
@@ -50,23 +50,14 @@ antlrcpp::Any qasm3_visitor::visitAliasStatement(
      // to the correct element of the alias array
      auto idx =
          symbol_table.evaluate_constant_integer_expression(expr->getText());

      // get the src_extracted element from the original register
      auto qubit_type = get_custom_opaque_type("Qubit", builder.getContext());
      auto src_extracted = builder.create<mlir::quantum::ExtractQubitOp>(
          location, qubit_type, allocated_symbol,
          get_or_create_constant_integer_value(
              idx, location, builder.getI64Type(), symbol_table, builder));
      // get the dest_extracted element from the alias register
      auto dest_extracted = builder.create<mlir::quantum::ExtractQubitOp>(
          location, qubit_type, alias_allocation,
          get_or_create_constant_integer_value(
              counter, location, builder.getI64Type(), symbol_table, builder));
      auto dest_idx = get_or_create_constant_integer_value(
          counter, location, builder.getI64Type(), symbol_table, builder);
      auto src_idx = get_or_create_constant_integer_value(
          idx, location, builder.getI64Type(), symbol_table, builder);
      ++counter;
      // use extracted with a new qassign dialect operation.
      // void qAssign(Qubit* dest, Qubit* src)
      builder.create<mlir::quantum::AssignQubitOp>(location, dest_extracted,
                                                   src_extracted);

      builder.create<mlir::quantum::AssignQubitOp>(
          location, alias_allocation, dest_idx, allocated_symbol, src_idx);
    }

  } else if (auto range_def = context->indexIdentifier()->rangeDefinition()) {
+71 −14
Original line number Diff line number Diff line
@@ -199,10 +199,10 @@ public:

    // Remove the old QuantumDialect QallocOp
    rewriter.replaceOp(op, qbit_array);
    rewriter.eraseOp(op);
    // Save the qubit array variable to the symbol table
    variables.insert({qreg_name, qbit_array});

    // std::cout << "Array 1D alloc:\n";
    // parentModule.dump();
    return success();
  }
};
@@ -692,8 +692,7 @@ public:
  // CTor: store seen variables
  explicit AssignQubitOpConversion(MLIRContext *context,
                                   std::map<std::string, mlir::Value> &vars)
      : ConversionPattern(
            mlir::quantum::AssignQubitOp::getOperationName(), 1,
      : ConversionPattern(mlir::quantum::AssignQubitOp::getOperationName(), 1,
                          context),
        variables(vars) {}

@@ -704,19 +703,77 @@ public:
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();
    auto location = parentModule->getLoc();
    // Source and Destinations are Qubit* type
    auto dest = operands[0];
    auto src = operands[1];
    // Cast source pointer to Qubit**
    auto bitcast = rewriter.create<LLVM::BitcastOp>(
    // Unpack destination and source array and indices
    auto dest_array = operands[0];
    auto dest_idx = operands[1];
    auto src_array = operands[2];
    auto src_idx = operands[3];
    FlatSymbolRefAttr array_get_elem_fn_ptr = [&]() {
      static const std::string qir_get_qubit_from_array =
          "__quantum__rt__array_get_element_ptr_1d";
      if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(
              qir_get_qubit_from_array)) {
        return SymbolRefAttr::get(qir_get_qubit_from_array, context);
      } else {
        // prototype should be (int64* : qreg, int64 : element) -> int64* :
        // qubit
        auto qubit_array_type =
            LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
        auto qubit_index_type = IntegerType::get(context, 64);
        // ret is i8*
        auto qbit_element_ptr_type =
            LLVM::LLVMPointerType::get(IntegerType::get(context, 8));

        auto get_ptr_qbit_ftype = LLVM::LLVMFunctionType::get(
            qbit_element_ptr_type,
            llvm::ArrayRef<Type>{qubit_array_type, qubit_index_type}, false);

        PatternRewriter::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(parentModule.getBody());
        rewriter.create<LLVM::LLVMFuncOp>(location, qir_get_qubit_from_array,
                                          get_ptr_qbit_ftype);

        return mlir::SymbolRefAttr::get(qir_get_qubit_from_array, context);
      }
    }();

    // Create the CallOp for the get element ptr 1d function
    auto get_dest_qbit_qir_call = rewriter.create<mlir::CallOp>(
        location, array_get_elem_fn_ptr,
        LLVM::LLVMPointerType::get(IntegerType::get(context, 8)),
        llvm::makeArrayRef(std::vector<mlir::Value>{dest_array, dest_idx}));

    auto get_src_qbit_qir_call = rewriter.create<mlir::CallOp>(
        location, array_get_elem_fn_ptr,
        LLVM::LLVMPointerType::get(IntegerType::get(context, 8)),
        llvm::makeArrayRef(std::vector<mlir::Value>{src_array, src_idx}));

    // Load source qubit
    auto src_bitcast = rewriter.create<LLVM::BitcastOp>(
        location,
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))),
        src);
    // Store source (Qubit**) to destination
    // auto store_qubit_ptr =
    //     rewriter.create<LLVM::StoreOp>(location, bitcast.res(), dest);
        get_src_qbit_qir_call.getResult(0));

    auto real_casted_src_qubit = rewriter.create<LLVM::LoadOp>(
        location,
        LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)),
        src_bitcast.res());

    // Destination: just cast the raw ptr to Qubit** to store the source Qubit*
    // to. Get the destination raw ptr (int8) and cast to Qubit**
    auto dest_bitcast = rewriter.create<LLVM::BitcastOp>(
        location,
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))),
        get_dest_qbit_qir_call.getResult(0));

    // Store source (Qubit*) to destination (Qubit**)
    rewriter.create<LLVM::StoreOp>(location, real_casted_src_qubit,
                                   dest_bitcast);
    rewriter.eraseOp(op);
    // std::cout << "After assign:\n";
    // parentModule.dump();
    return success();
  }
};