Unverified Commit 45302eb9 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #192 from tnguyen-ornl/tnguyen/mlir-const-eval-affine-loop

Consteval loop vars if possible
parents 891cb352 b57299f1
Loading
Loading
Loading
Loading
Loading
+52 −0
Original line number Diff line number Diff line
@@ -385,6 +385,58 @@ cx first_and_last_qubit[0], first_and_last_qubit[1];
  }
}

TEST(qasm3PassManagerTester, checkConstEvalLoopUnroll) {
  {
    // Unroll the loop with const vars as loop bounds
    const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";

const n = 3;
qubit qb;

for i in [0:2*n] {
 x qb;
}
)#";
    auto llvm =
        qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
    std::cout << "LLVM:\n" << llvm << "\n";

    // Get the main kernel section only
    llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
    const auto last = llvm.find_first_of("}");
    llvm = llvm.substr(0, last + 1);
    std::cout << "LLVM:\n" << llvm << "\n";
    // Cancel all
    EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 0);
  }

  {
    // Unroll the loop with const vars as loop bounds
    const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";

const n = 3;
qubit qb;

for i in [0:2*n + 1] {
 x qb;
}
)#";
    auto llvm =
        qcor::mlir_compile(src, "test_kernel1", qcor::OutputType::LLVMIR, false);
    std::cout << "LLVM:\n" << llvm << "\n";

    // Get the main kernel section only
    llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel1"));
    const auto last = llvm.find_first_of("}");
    llvm = llvm.substr(0, last + 1);
    std::cout << "LLVM:\n" << llvm << "\n";
    // One X gate left
    EXPECT_EQ(countSubstring(llvm, "__quantum__qis__x"), 1);
  }
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+19 −3
Original line number Diff line number Diff line
@@ -26,7 +26,8 @@ void ScopedSymbolTable::replace_symbol(mlir::Value old_value,
  }
}

int64_t ScopedSymbolTable::evaluate_constant_integer_expression(
std::optional<int64_t>
ScopedSymbolTable::try_evaluate_constant_integer_expression(
    const std::string expr_str) {
  auto all_constants = get_constant_integer_variables();
  std::vector<std::string> variable_names;
@@ -58,7 +59,22 @@ int64_t ScopedSymbolTable::evaluate_constant_integer_expression(
  if (parser.compile(expr_str, expr)) {
    ref = expr.value();
  } else {
    printErrorMessage("Failed to evaluate cnostant integer expression: " +
    // Cannot eval to a const int.
    return std::nullopt;
  }

  return (int64_t)ref;
}

int64_t ScopedSymbolTable::evaluate_constant_integer_expression(
    const std::string expr_str) {
  const std::optional<int64_t> try_eval =
      try_evaluate_constant_integer_expression(expr_str);
  double ref = 0.0;
  if (try_eval.has_value()) {
    ref = try_eval.value();
  } else {
    printErrorMessage("Failed to evaluate constant integer expression: " +
                      expr_str + ". Must be a constant integer type.");
  }

+4 −0
Original line number Diff line number Diff line
@@ -221,7 +221,11 @@ class ScopedSymbolTable {
    return ret;
  }

  // Eval a const int expression (throw if failed)
  int64_t evaluate_constant_integer_expression(const std::string expr);
  // Returns null if this expression cannot be const-eval to an integer value.
  std::optional<int64_t>
  try_evaluate_constant_integer_expression(const std::string expr);

  mlir::FuncOp get_seen_function(const std::string name) {
    if (!seen_functions.count(name)) {
+48 −19
Original line number Diff line number Diff line
@@ -230,6 +230,15 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
          c_value = get_or_create_constant_integer_value(c, location, int_type,
                                                         symbol_table, builder);

      const auto const_eval_a_val =
          symbol_table.try_evaluate_constant_integer_expression(
              range->expression(0)->getText());
      if (const_eval_a_val.has_value()) {
        // std::cout << "A val = " << const_eval_a_val.value() << "\n";
        a_value = get_or_create_constant_integer_value(const_eval_a_val.value(),
                                                       location, int_type,
                                                       symbol_table, builder);
      } else {
        qasm3_expression_generator exp_generator(builder, symbol_table,
                                                 file_name);
        exp_generator.visit(range->expression(0));
@@ -237,8 +246,18 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        if (a_value.getType().isa<mlir::MemRefType>()) {
          a_value = builder.create<mlir::LoadOp>(location, a_value);
        }
      }

      if (n_expr == 3) {
        const auto const_eval_b_val =
          symbol_table.try_evaluate_constant_integer_expression(
              range->expression(2)->getText());
        if (const_eval_b_val.has_value()) {
          // std::cout << "B val = " << const_eval_b_val.value() << "\n";
          b_value = get_or_create_constant_integer_value(
              const_eval_b_val.value(), location, int_type, symbol_table,
              builder);
        } else {
          qasm3_expression_generator exp_generator(builder, symbol_table,
                                                   file_name);
          exp_generator.visit(range->expression(2));
@@ -246,7 +265,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
          if (b_value.getType().isa<mlir::MemRefType>()) {
            b_value = builder.create<mlir::LoadOp>(location, b_value);
          }

        }
        if (symbol_table.has_symbol(range->expression(1)->getText())) {
          printErrorMessage("You must provide loop step as a constant value.",
                            context);
@@ -263,6 +282,15 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
              c, location, a_value.getType(), symbol_table, builder);
        }

      } else {
        const auto const_eval_b_val =
            symbol_table.try_evaluate_constant_integer_expression(
                range->expression(1)->getText());
        if (const_eval_b_val.has_value()) {
          // std::cout << "B val = " << const_eval_b_val.value() << "\n";
          b_value = get_or_create_constant_integer_value(
              const_eval_b_val.value(), location, int_type, symbol_table,
              builder);
        } else {
          qasm3_expression_generator exp_generator(builder, symbol_table,
                                                   file_name);
@@ -272,6 +300,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
            b_value = builder.create<mlir::LoadOp>(location, b_value);
          }
        }
      }

      const std::string program_block_str = program_block->getText();
      // std::cout << "HOWDY:\n" << program_block_str << "\n";