Commit 76183e1b authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

Merge branch 'master' of https://github.com/ornl-qci/qcor

parents 4980a014 1cb2f848
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -29,9 +29,10 @@ struct QuantumInlinerInterface : public DialectInlinerInterface {
  // FIXME: there is a weird error when qalloc is inlined at MLIR level
  // hence, just allow VSOp to be inlined for the timebeing.
  // i.e. all quantum subroutines that only contain VSOp's can be inlined.
  bool isLegalToInline(Operation *op, Region *regione, bool,
  bool isLegalToInline(Operation *op, Region *region, bool,
                       BlockAndValueMapping &) const final {
    if (dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(op)) {
    if (dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(op) ||
        dyn_cast_or_null<mlir::quantum::ExtractQubitOp>(op)) {
      return true;
    }

+152 −0
Original line number Diff line number Diff line
@@ -152,6 +152,158 @@ cx q[0], q[1];
  EXPECT_EQ(countSubstring(llvm, "__quantum__rt__qubit_release_array"), 0);
}

TEST(qasm3PassManagerTester, checkLoopUnroll) {
  // Unroll the loop:
  // cancel all X gates; combine rx
  const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";
qubit q[2];
for i in [0:10] {
    x q[0];
    rx(0.123) q[1];
}
)#";
  auto llvm =
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only 
  llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // Only a single Rx remains (combine all angles)
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rx"), 1);
}

TEST(qasm3PassManagerTester, checkLoopUnrollTrotter) {
  // Unroll the loop:
  // Trotter decompose
  const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";
qubit qq[2];
for i in [0:100] {
    h qq;
    cx qq[0], qq[1];
    rx(0.0123) qq[1];
    cx qq[0], qq[1];
    h qq;
}
)#";
  auto llvm =
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only 
  llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // Only a single Rx remains (combine all angles)
  // 2 Hadamard before + 1 CX before
  // 2 Hadamard after + 1 CX after
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 7);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rx"), 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__h"), 4);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__cnot"), 2);
}

TEST(qasm3PassManagerTester, checkLoopUnrollWithInline) {
  // Unroll the loop and inline
  // Trotter decompose
  // Note: using the inv (adjoint) modifier is not supported
  // since it is a runtime feature...
  // hence, we need to make the adjoint explicit.
  const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";
def cnot_ladder() qubit[4]:q {
  h q[0];
  h q[1];
  cx q[0], q[1];
  cx q[1], q[2];
  cx q[2], q[3];
}

def cnot_ladder_inv() qubit[4]:q {
  cx q[2], q[3];
  cx q[1], q[2];
  cx q[0], q[1];
  h q[1];
  h q[0];
}

qubit q[4];
double theta = 0.01;
for i in [0:100] {
  cnot_ladder q;
  rz(theta) q[3];
  cnot_ladder_inv q;
}
)#";
  auto llvm =
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only 
  llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // Only a single Rz remains (combine all angles)
  // 2 Hadamard before + 3 CX before
  // 2 Hadamard after + 3 CX after
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 11);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rz"), 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__h"), 4);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__cnot"), 6);
}

TEST(qasm3PassManagerTester, checkAffineLoopRevert) {
  // Check loop with negative step:
  const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";
def cnot_ladder() qubit[4]:q {
  h q;
  for i in [0:3] {
    cx q[i], q[i + 1];
  }
}

def cnot_ladder_inv() qubit[4]:q {
  for i in [3:-1:0] {
    cx q[i-1], q[i];
  }
  
  h q;
}

qubit q[4];
double theta = 0.01;
for i in [0:100] {
  cnot_ladder q;
  rz(theta) q[3];
  cnot_ladder_inv q;
}
)#";
  auto llvm =
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only 
  llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // Only a single Rz remains (combine all angles)
  // 4 Hadamard before + 3 CX before
  // 4 Hadamard after + 3 CX after
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 15);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rz"), 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__h"), 8);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__cnot"), 6);
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+19 −4
Original line number Diff line number Diff line
@@ -122,6 +122,10 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
          current_value = builder.create<mlir::ZeroExtendIOp>(
              location, current_value, builder.getI64Type());
        }
        if (!current_value.getType().isa<mlir::IntegerType>()) {
          current_value = builder.create<mlir::IndexCastOp>(
              location, builder.getI64Type(), current_value);
        }
        update_current_value(builder.create<mlir::quantum::ExtractQubitOp>(
            location, get_custom_opaque_type("Qubit", builder.getContext()),
            indexed_variable_value, current_value));
@@ -404,8 +408,10 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
        }

        createOp<mlir::AddFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
                 rhs.getType().isa<mlir::IntegerType>()) {
      } else if ((lhs.getType().isa<mlir::IntegerType>() ||
                  lhs.getType().isa<mlir::IndexType>()) &&
                 (rhs.getType().isa<mlir::IntegerType>() ||
                  rhs.getType().isa<mlir::IndexType>())) {
        if (lhs.getType().getIntOrFloatBitWidth() <
            rhs.getType().getIntOrFloatBitWidth()) {
          lhs =
@@ -416,6 +422,10 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
          rhs =
              builder.create<mlir::ZeroExtendIOp>(location, rhs, lhs.getType());
        }

        if (lhs.getType() != rhs.getType()) {
          rhs = builder.create<mlir::IndexCastOp>(location, lhs.getType(), rhs);
        }
        createOp<mlir::AddIOp>(location, lhs, rhs).result();
      } else {
        printErrorMessage("Could not perform addition, incompatible types: ",
@@ -434,8 +444,13 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
        }

        createOp<mlir::SubFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
                 rhs.getType().isa<mlir::IntegerType>()) {
      } else if ((lhs.getType().isa<mlir::IntegerType>() ||
                  lhs.getType().isa<mlir::IndexType>()) &&
                 (rhs.getType().isa<mlir::IntegerType>() ||
                  rhs.getType().isa<mlir::IndexType>())) {
        if (lhs.getType() != rhs.getType()) {
          rhs = builder.create<mlir::IndexCastOp>(location, lhs.getType(), rhs);
        }
        createOp<mlir::SubIOp>(location, lhs, rhs).result();
      } else {
        printErrorMessage("Could not perform subtraction, incompatible types: ",
+142 −53
Original line number Diff line number Diff line
@@ -5,6 +5,64 @@ using symbol_table_t = exprtk::symbol_table<double>;
using expression_t = exprtk::expression<double>;
using parser_t = exprtk::parser<double>;

namespace {
/// Creates a single affine "for" loop, iterating from lbs to ubs with
/// the given step.
/// to construct the body of the loop and is passed the induction variable.
void affineLoopBuilder(mlir::Value lbs_val, mlir::Value ubs_val, int64_t step,
                       std::function<void(mlir::Value)> bodyBuilderFn,
                       mlir::OpBuilder &builder, mlir::Location &loc) {
  if (!ubs_val.getType().isa<mlir::IndexType>()) {
    ubs_val =
        builder.create<mlir::IndexCastOp>(loc, builder.getIndexType(), ubs_val);
  }
  if (!lbs_val.getType().isa<mlir::IndexType>()) {
    lbs_val =
        builder.create<mlir::IndexCastOp>(loc, builder.getIndexType(), lbs_val);
  }
  // Note: Affine for loop only accepts **positive** step:
  // The stride, represented by step, is a positive constant integer which
  // defaults to “1” if not present.
  assert(step != 0);
  if (step > 0) {
    mlir::ValueRange lbs(lbs_val);
    mlir::ValueRange ubs(ubs_val);
    // Create the actual loop
    builder.create<mlir::AffineForOp>(
        loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
        builder.getMultiDimIdentityMap(ubs.size()), step, llvm::None,
        [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc,
            mlir::Value iv, mlir::ValueRange itrArgs) {
          mlir::OpBuilder::InsertionGuard guard(nestedBuilder);
          bodyBuilderFn(iv);
          nestedBuilder.create<mlir::AffineYieldOp>(nestedLoc);
        });
  } else {
    // Negative step:
    // a -> b step c (minus)
    // -a -> -b step c (plus) and minus the loop var
    mlir::Value minus_one = builder.create<mlir::ConstantOp>(
        loc, mlir::IntegerAttr::get(lbs_val.getType(), -1));
    lbs_val = builder.create<mlir::MulIOp>(loc, lbs_val, minus_one).result();
    ubs_val = builder.create<mlir::MulIOp>(loc, ubs_val, minus_one).result();
    mlir::ValueRange lbs(lbs_val);
    mlir::ValueRange ubs(ubs_val);
    builder.create<mlir::AffineForOp>(
        loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
        builder.getMultiDimIdentityMap(ubs.size()), -step, llvm::None,
        [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc,
            mlir::Value iv, mlir::ValueRange itrArgs) {
          mlir::OpBuilder::InsertionGuard guard(nestedBuilder);
          mlir::Value minus_one_idx = nestedBuilder.create<mlir::ConstantOp>(
              nestedLoc, mlir::IntegerAttr::get(iv.getType(), -1));
          bodyBuilderFn(
              nestedBuilder.create<mlir::MulIOp>(nestedLoc, iv, minus_one_idx)
                  .result());
          nestedBuilder.create<mlir::AffineYieldOp>(nestedLoc);
        });
  }
}
} // namespace
namespace qcor {
antlrcpp::Any qasm3_visitor::visitLoopStatement(
    qasm3Parser::LoopStatementContext* context) {
@@ -215,6 +273,34 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        }
      }

      const std::string program_block_str = program_block->getText();
      // std::cout << "HOWDY:\n" << program_block_str << "\n";

      // HACK: Currently, we don't handle 'if', 'break', 'continue'
      // in the Affine for loop yet.
      if (program_block_str.find("if") == std::string::npos &&
          program_block_str.find("break") == std::string::npos &&
          program_block_str.find("continue") == std::string::npos &&
          // This is equivalent to an "if"
          program_block_str.find("QCOR_EXPECT_TRUE") == std::string::npos &&
          // We can only handle nested for loops if the inner one is also an
          // affine one For now, don't do that since we're not sure.
          program_block_str.find("for") == std::string::npos &&
          // While loop is not converted to affine yet.
          program_block_str.find("while") == std::string::npos) {
        // Can use Affine for loop....
        affineLoopBuilder(
            a_value, b_value, c,
            [&](mlir::Value loop_var) {
              // Create a new scope for the for loop
              symbol_table.enter_new_scope();
              symbol_table.add_symbol(idx_var_name, loop_var, {}, true);
              visitChildren(program_block);
              symbol_table.exit_scope();
            },
            builder, location);
      } else {
        // Need to use the legacy for loop construction for now...
        // Create a new scope for the for loop
        symbol_table.enter_new_scope();

@@ -226,22 +312,24 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

        // Save the current builder point
        // auto savept = builder.saveInsertionPoint();
      auto loaded_var = builder.create<mlir::LoadOp>(location, loop_var_memref);
        auto loaded_var =
            builder.create<mlir::LoadOp>(location, loop_var_memref);

        symbol_table.add_symbol(idx_var_name, loaded_var, {}, true);

        // Strategy...

      // We need to create a header block to check that loop var is still valid
      // it will branch at the end to the body or the exit
        // We need to create a header block to check that loop var is still
        // valid it will branch at the end to the body or the exit

        // Then we create the body block, it should branch to the incrementor
        // block

        // Then we create the incrementor block, it should branch back to header

      // Any downstream children that will create blocks will need to know what
      // the fallback block for them is, and it should be the incrementor block
        // Any downstream children that will create blocks will need to know
        // what the fallback block for them is, and it should be the incrementor
        // block
        auto savept = builder.saveInsertionPoint();
        auto currRegion = builder.getBlock()->getParent();
        auto headerBlock = builder.createBlock(currRegion, currRegion->end());
@@ -256,7 +344,9 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

        auto load = builder.create<mlir::LoadOp>(location, loop_var_memref);
        auto cmp = builder.create<mlir::CmpIOp>(
          location, c > 0 ? mlir::CmpIPredicate::slt : mlir::CmpIPredicate::sge, load, b_value);
            location,
            c > 0 ? mlir::CmpIPredicate::slt : mlir::CmpIPredicate::sge, load,
            b_value);
        builder.create<mlir::CondBranchOp>(location, cmp, bodyBlock, exitBlock);

        builder.setInsertionPointToStart(bodyBlock);
@@ -280,7 +370,6 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

        auto add = builder.create<mlir::AddIOp>(location, load_inc, c_value);


        builder.create<mlir::StoreOp>(location, add, loop_var_memref);

        builder.create<mlir::BranchOp>(location, headerBlock);
@@ -290,7 +379,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        symbol_table.set_last_created_block(exitBlock);

        symbol_table.exit_scope();

      }
    } else {
      printErrorMessage(
          "For loops must be of form 'for i in {SET}' or 'for i in [RANGE]'.");
+4 −1
Original line number Diff line number Diff line
@@ -94,7 +94,10 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(
              mlir::Identifier::get("quantum", builder.getContext());
          auto qubit_type = mlir::OpaqueType::get(builder.getContext(), dialect,
                                                  qubit_type_name);

          if (!qbit.getType().isa<mlir::IntegerType>()) {
            qbit = builder.create<mlir::IndexCastOp>(
                location, builder.getI64Type(), qbit);
          }
          value = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qubits, qbit);
        } else {
Loading