Commit c1e0ad29 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Using SCF ForOp to compute the integer power



The reason is that CFG block-based implementation is not compatible when this is nested inside a loop construct.
e.g., YieldOp at the end needs to be aware of which block is the exit one.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 27525e59
Loading
Loading
Loading
Loading
+85 −0
Original line number Diff line number Diff line
@@ -18,6 +18,91 @@ QCOR_EXPECT_TRUE(test < .01);
  EXPECT_FALSE(qcor::execute(global_const, "global_const"));
}

TEST(qasm3VisitorTester, checkPower) {
  const std::string power_test = R"#(OPENQASM 3;
include "qelib1.inc";
int j = 5;
int y = 2;
int test1 = 2^(j-y);
QCOR_EXPECT_TRUE(test1 == 8);
int test2 = j^(j-y);
QCOR_EXPECT_TRUE(test2 == 125);
)#";
  auto mlir = qcor::mlir_compile(power_test, "power_test",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  EXPECT_FALSE(qcor::execute(power_test, "power_test"));
}

// Check QPE that has complex classical arithmetic:
TEST(qasm3VisitorTester, checkQPE) {
  const std::string qpe_test = R"#(OPENQASM 3;
const n_counting = 3;

// For this example, the oracle is the T gate 
// on the provided qubit
gate oracle b {
    t b;
}

// Inverse QFT subroutine on n_counting qubits
def iqft qubit[n_counting]:qq {
    for i in [0:n_counting/2] {
        swap qq[i], qq[n_counting-i-1];
    }
    for i in [0:n_counting-1] {
        h qq[i];
        int j = i + 1;
        int y = i;
        while (y >= 0) {
            double theta = -pi / (2^(j-y));
            cphase(theta) qq[j], qq[y];
            y -= 1;
        }
    }
    h qq[n_counting-1];
}

// Define some counting qubits
qubit counting[n_counting];

// Allocate the qubit we'll 
// put the initial state on
qubit state;

// We want T |1> = exp(2*i*pi*phase) |1> = exp(i*pi/4)
// compute phase, should be 1 / 8;

// Initialize to |1>
x state;

// Put all others in a uniform superposition
h counting;

// Loop over and create ctrl-U**2k
int repetitions = 1;
for i in [0:n_counting] {
    ctrl @ pow(repetitions) @ oracle counting[i], state;
    repetitions *= 2;
}

// Run inverse QFT 
iqft counting;

// Now lets measure the counting qubits
bit c[n_counting];
measure counting -> c;

// Backend is QPP which is lsb, 
// so return should be 100
print(c);
)#";
  auto mlir = qcor::mlir_compile(qpe_test, "qpe_test",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  EXPECT_FALSE(qcor::execute(qpe_test, "qpe_test"));
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+34 −90
Original line number Diff line number Diff line

#include "expression_handler.hpp"
#include "mlir/Dialect/SCF/SCF.h"

using namespace qasm3;

namespace qcor {
@@ -555,100 +557,42 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression(
      // }

      // Will need to compute this via a loop
      auto tmp = get_or_create_constant_integer_value(
          0, location, builder.getI64Type(), symbol_table, builder);
      auto tmp2 = get_or_create_constant_index_value(0, location, 64,
                                                     symbol_table, builder);
      auto tmp3 = get_or_create_constant_integer_value(
          1, location, lhs_element_type, symbol_table, builder);
      llvm::ArrayRef<mlir::Value> zero_index(tmp2);
      // Upper bound = exponent (rhs)
      mlir::Value ubs_val = rhs;
      if (!ubs_val.getType().isa<mlir::IndexType>()) {
        ubs_val = builder.create<mlir::IndexCastOp>(
            location, builder.getIndexType(), ubs_val);
      }

      // Create the result memref, initialized to 1
      llvm::ArrayRef<int64_t> shaperef{};
      auto mem_type = mlir::MemRefType::get(shaperef, lhs_element_type);

      auto integer_attr2 = mlir::IntegerAttr::get(lhs_element_type, 0);
      
      assert(integer_attr2.getType().cast<mlir::IntegerType>().isSignless());
      auto ret2 = builder.create<mlir::ConstantOp>(location, integer_attr2);
      
      auto integer_attr3 = mlir::IntegerAttr::get(lhs_element_type, 1);
      assert(integer_attr3.getType().cast<mlir::IntegerType>().isSignless());
      auto ret3 = builder.create<mlir::ConstantOp>(location, integer_attr3);

      mlir::Value loop_var_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, builder.getI64Type()));
      builder.create<mlir::StoreOp>(location, ret2, loop_var_memref);

      mlir::Value product_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, lhs_element_type));
      builder.create<mlir::StoreOp>(location, ret3, product_memref);

      auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), 1);
      auto ret = builder.create<mlir::ConstantOp>(location, integer_attr);
      auto b_val = rhs;
      auto c_val = ret;

      // Save the current builder point
      // auto savept = builder.saveInsertionPoint();
      auto loaded_var = builder.create<mlir::LoadOp>(
          location, loop_var_memref);  //, zero_index);

      auto savept = builder.saveInsertionPoint();
      auto currRegion = builder.getBlock()->getParent();
      auto headerBlock = builder.createBlock(currRegion, currRegion->end());
      auto bodyBlock = builder.createBlock(currRegion, currRegion->end());
      auto incBlock = builder.createBlock(currRegion, currRegion->end());
      mlir::Block* exitBlock =
          builder.createBlock(currRegion, currRegion->end());
      builder.restoreInsertionPoint(savept);

      builder.create<mlir::BranchOp>(location, headerBlock);
      builder.setInsertionPointToStart(headerBlock);

      auto load = builder.create<mlir::LoadOp>(location, loop_var_memref)
                      .result();  //, zero_index);

      if (load.getType().getIntOrFloatBitWidth() <
          b_val.getType().getIntOrFloatBitWidth()) {
        load = builder.create<mlir::ZeroExtendIOp>(location, load,
                                                   b_val.getType());
      } else if (b_val.getType().getIntOrFloatBitWidth() <
                 load.getType().getIntOrFloatBitWidth()) {
        b_val = builder.create<mlir::ZeroExtendIOp>(location, b_val,
                                                    load.getType());
      }

      auto cmp = builder.create<mlir::CmpIOp>(
          location, mlir::CmpIPredicate::slt, load, b_val);
      builder.create<mlir::CondBranchOp>(location, cmp, bodyBlock, exitBlock);

      builder.setInsertionPointToStart(bodyBlock);
      // body needs to load the loop variable
      auto x = builder.create<mlir::LoadOp>(location,
                                            product_memref);  //, zero_index);

      auto mult = builder.create<mlir::MulIOp>(location, x, lhs);
      builder.create<mlir::StoreOp>(location, mult, product_memref);  //,
      // llvm::makeArrayRef(std::vector<mlir::Value>{tmp2}));

      builder.create<mlir::BranchOp>(location, incBlock);

      builder.setInsertionPointToStart(incBlock);
      auto load_inc = builder.create<mlir::LoadOp>(
          location, loop_var_memref);  //, zero_index);
      auto add = builder.create<mlir::AddIOp>(location, load_inc, c_val);

      builder.create<mlir::StoreOp>(location, add, loop_var_memref);  //,
      // llvm::makeArrayRef(std::vector<mlir::Value>{tmp2}));

      builder.create<mlir::BranchOp>(location, headerBlock);

      builder.setInsertionPointToStart(exitBlock);

      symbol_table.set_last_created_block(exitBlock);
      builder.create<mlir::StoreOp>(
          location,
          builder.create<mlir::ConstantOp>(
              location, mlir::IntegerAttr::get(lhs_element_type, 1)),
          product_memref);

      current_value = builder.create<mlir::LoadOp>(
          location, product_memref);  //, zero_index);
      // Lower bound = 0
      mlir::Value lbs_val = get_or_create_constant_index_value(
          0, location, 64, symbol_table, builder);
      mlir::Value step_val = get_or_create_constant_index_value(
          1, location, 64, symbol_table, builder);
      builder.create<mlir::scf::ForOp>(
          location, lbs_val, ubs_val, step_val, llvm::None,
          [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc,
              mlir::Value iv, mlir::ValueRange itrArgs) {
            mlir::OpBuilder::InsertionGuard guard(nestedBuilder);
            // Load - Multiply - Store
            auto x =
                nestedBuilder.create<mlir::LoadOp>(nestedLoc, product_memref);
            auto mult = nestedBuilder.create<mlir::MulIOp>(nestedLoc, x, lhs);
            nestedBuilder.create<mlir::StoreOp>(nestedLoc, mult,
                                                product_memref);
            nestedBuilder.create<mlir::scf::YieldOp>(nestedLoc);
          });
      current_value = builder.create<mlir::LoadOp>(location, product_memref);
      return 0;
    } else if (lhs_element_type.isa<mlir::FloatType>()) {
      if (!rhs_element_type.isa<mlir::FloatType>()) {