Unverified Commit cf927e1e authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #188 from tnguyen-ornl/tnguyen/add-mlir-opt

Add Reset optimization pass and handling qubit SSA variables in array aliasing
parents ddf60396 9f88a90d
Loading
Loading
Loading
Loading
Loading
+81 −0
Original line number Diff line number Diff line
@@ -32,6 +32,30 @@ cx q[0], q[1];
  EXPECT_TRUE(llvm.find("__quantum__qis") == std::string::npos);
}

TEST(qasm3PassManagerTester, checkResetSimplification) {
  const std::string src = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q[2];

reset q[0];
x q[1];
reset q[0];
)#";
  auto llvm =
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";

  // Get the main kernel section only
  llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // One reset and one X:
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 2);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__x"), 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__reset"), 1);
}

TEST(qasm3PassManagerTester, checkRotationMerge) {
  const std::string src = R"#(OPENQASM 3;
include "qelib1.inc";
@@ -304,6 +328,63 @@ for i in [0:100] {
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__cnot"), 6);
}

TEST(qasm3PassManagerTester, checkQubitArrayAlias) {
  {
    // Check SSA value chain with alias
    // h-t-h == rx (t is equiv. to rz)
    const std::string src = R"#(OPENQASM 3;
include "qelib1.inc";

qubit q[6];
let my_reg = q[1, 3, 5];

h q[1];
t my_reg[0];
h q[1];
)#";
    auto llvm =
        qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
    std::cout << "LLVM:\n" << llvm << "\n";

    // Get the main kernel section only (there is the oracle LLVM section as
    // well)
    llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
    const auto last = llvm.find_first_of("}");
    llvm = llvm.substr(0, last + 1);
    std::cout << "LLVM:\n" << llvm << "\n";
    EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 1);
    // One Rx
    EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rx"), 1);
  }

  {
    // Check optimization can work with alias array
    const std::string src = R"#(OPENQASM 3;
include "qelib1.inc";

qubit q[4];
let first_and_last_qubit = q[0] || q[3];

cx q[0], q[3];
cx first_and_last_qubit[0], first_and_last_qubit[1];
)#";
    auto llvm =
        qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
    std::cout << "LLVM:\n" << llvm << "\n";

    // Get the main kernel section only (there is the oracle LLVM section as
    // well)
    llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
    const auto last = llvm.find_first_of("}");
    llvm = llvm.substr(0, last + 1);
    std::cout << "LLVM:\n" << llvm << "\n";
    // Cancel all => No gates, extract, or alloc/dealloc:
    EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 0);
    // Make sure all runtime (alias construction) functions are removed as well.
    EXPECT_EQ(countSubstring(llvm, "__quantum__"), 0);
  }
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+4 −4
Original line number Diff line number Diff line
@@ -71,8 +71,8 @@ mlir::Type get_custom_opaque_type(const std::string& type,
mlir::Value get_or_extract_qubit(const std::string &qreg_name,
                                 const std::size_t idx, mlir::Location location,
                                 ScopedSymbolTable &symbol_table,
                                 mlir::OpBuilder& builder, std::string prepended_st_name) {
  auto key = prepended_st_name + qreg_name + std::to_string(idx);
                                 mlir::OpBuilder &builder) {
  auto key = symbol_table.array_qubit_symbol_name(qreg_name, idx);
  if (symbol_table.has_symbol(key)) {
    return symbol_table.get_symbol(key);  // global_symbol_table[key];
  } else {
+3 −3
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ mlir::Type get_custom_opaque_type(const std::string& type,
mlir::Value get_or_extract_qubit(const std::string &qreg_name,
                                 const std::size_t idx, mlir::Location location,
                                 ScopedSymbolTable &symbol_table,
                                 mlir::OpBuilder& builder, std::string prepended_st_name = "");
                                 mlir::OpBuilder &builder);

mlir::Value get_or_create_constant_integer_value(
    const std::size_t idx, mlir::Location location, mlir::Type int_like_type,
+122 −6
Original line number Diff line number Diff line
@@ -11,7 +11,102 @@
#include "mlir/IR/BuiltinTypes.h"

namespace qcor {
using SymbolTable = std::map<std::string, mlir::Value>;
// using SymbolTable = std::map<std::string, mlir::Value>;
struct SymbolTable {
  std::map<std::string, mlir::Value>::iterator begin() {
    return var_name_to_value.begin();
  }
  std::map<std::string, mlir::Value>::iterator end() {
    return var_name_to_value.end();
  }

  // Check if we have this symbol:
  // If this is a *root* (master) symbol, backed by a mlir::Value
  // or this is an alias (by reference), normally only for Qubits (SSA values)
  bool has_symbol(const std::string &var_name) {
    if (var_name_to_value.find(var_name) != var_name_to_value.end()) {
      return true;
    }
    const auto alias_name_check_iter =
        ref_var_name_to_orig_var_name.find(var_name);
    if (alias_name_check_iter != ref_var_name_to_orig_var_name.end()) {
      const std::string &original_var_name = alias_name_check_iter->second;
      return var_name_to_value.find(original_var_name) !=
             var_name_to_value.end();
    }
    return false;
  }

  // Add a reference alias, i.e. the two variable names are bound
  // to a single mlir::Value.
  // Note: chaining of aliasing is traced to the root var name:
  // e.g. we can support a, b (refers to a), then c refers to b.
  void add_alias(const std::string &orig_var_name,
                 const std::string &alias_var_name) {
    if (ref_var_name_to_orig_var_name.find(orig_var_name) !=
        ref_var_name_to_orig_var_name.end()) {
      // The original var name is an alias itself...
      const std::string &root_var_name =
          ref_var_name_to_orig_var_name[orig_var_name];
      ref_var_name_to_orig_var_name[alias_var_name] = root_var_name;
    } else {
      assert(var_name_to_value.find(orig_var_name) != var_name_to_value.end());
      ref_var_name_to_orig_var_name[alias_var_name] = orig_var_name;
    }
  }

  // Get the symbol (mlir::Value) taking into account potential alias chaining.
  mlir::Value get_symbol(const std::string &var_name) {
    auto iter = var_name_to_value.find(var_name);
    if (iter != var_name_to_value.end()) {
      return iter->second;
    }

    auto alias_iter = ref_var_name_to_orig_var_name.find(var_name);
    if (alias_iter != ref_var_name_to_orig_var_name.end()) {
      const std::string &root_var_name = alias_iter->second;
      assert(var_name_to_value.find(root_var_name) != var_name_to_value.end());
      return var_name_to_value[root_var_name];
    }
    printErrorMessage("Unknown symbol '" + var_name + "'.");
    return mlir::Value();
  }

  void add_or_update_symbol(const std::string &var_name, mlir::Value value) {
    var_name_to_value[var_name] = value;
  }

  // Compatible w/ a raw map (assuming the variable is original/root)
  mlir::Value &operator[](const std::string &var_name) {
    return var_name_to_value[var_name];
  }

  mlir::Value &at(const std::string &var_name) {
    return var_name_to_value.at(var_name);
  }

  void insert(const std::pair<std::string, mlir::Value> &new_var) {
    var_name_to_value.insert(new_var);
  }

  std::map<std::string, mlir::Value>::iterator
  find(const std::string &var_name) {
    return var_name_to_value.find(var_name);
  }

  std::map<std::string, mlir::Value>::size_type
  count(const std::string &var_name) const {
    return var_name_to_value.count(var_name);
  }

private:
  std::map<std::string, mlir::Value> var_name_to_value;
  // By reference var name aliasing map:
  // track a variable name representing references to the original mlir::Value,
  // e.g. qubit aliasing from slicing.
  std::unordered_map<std::string, std::string> ref_var_name_to_orig_var_name;
};

using ConstantIntegerTable =
    std::map<std::pair<std::uint64_t, int>, mlir::Value>;

@@ -187,8 +282,7 @@ class ScopedSymbolTable {

  bool has_symbol(const std::string variable_name, const std::size_t scope) {
    for (int i = scope; i >= 0; i--) { // nasty bug, auto instead of int...
      if (!scoped_symbol_tables[i].empty() &&
          scoped_symbol_tables[i].count(variable_name)) {
      if (scoped_symbol_tables[i].has_symbol(variable_name)) {
        return true;
      }
    }
@@ -196,6 +290,15 @@ class ScopedSymbolTable {
    return false;
  }

  void add_symbol_ref_alias(const std::string &orig_variable_name,
                            const std::string &alias_ref_variable_name) {
    // Sanity check for debug
    assert(has_symbol(orig_variable_name));
    assert(!has_symbol(alias_ref_variable_name));
    scoped_symbol_tables[current_scope].add_alias(orig_variable_name,
                                                  alias_ref_variable_name);
  }

  SymbolTable& get_global_symbol_table() { return scoped_symbol_tables[0]; }

  template <typename OpTy>
@@ -227,8 +330,8 @@ class ScopedSymbolTable {
  mlir::Value get_symbol(const std::string variable_name,
                         const std::size_t scope) {
    for (auto i = scope; i >= 0; i--) {
      if (scoped_symbol_tables[i].count(variable_name)) {
        return scoped_symbol_tables[i][variable_name];
      if (scoped_symbol_tables[i].has_symbol(variable_name)) {
        return scoped_symbol_tables[i].get_symbol(variable_name);
      }
    }

@@ -305,6 +408,19 @@ class ScopedSymbolTable {
    return current_scope >= 1 ? current_scope - 1 : 0;
  }

  // Util to construct a symbol name for qubit within an array (qreg)
  // This is to make sure we have a consitent symbol naming convention (for SSA tracking).
  std::string array_qubit_symbol_name(const std::string &qreg_name,
                                      const std::string &index_str) {
    // Sanity check: we should have added the qreg var to the symbol table.
    assert(has_symbol(qreg_name));
    // Use '%' separator to prevent name clashes with user-defined variables
    return qreg_name + '%' + index_str;
  }
  std::string array_qubit_symbol_name(const std::string &qreg_name, int index) {
    return array_qubit_symbol_name(qreg_name, std::to_string(index));
  }

  ~ScopedSymbolTable() {}
};
}  // namespace qcor
 No newline at end of file
+81 −4
Original line number Diff line number Diff line
@@ -89,6 +89,24 @@ antlrcpp::Any qasm3_visitor::visitAliasStatement(
                  builder);
              auto src_idx = get_or_create_constant_integer_value(
                  idx, location, builder.getI64Type(), symbol_table, builder);

              // Put the *alias* qubit (alias-name + index) into the symbol
              // table: mapped to the original qubit:
              const std::string alias_qubit_var_name =
                  symbol_table.array_qubit_symbol_name(in_aliasName, counter);
              const std::string original_qubit_var_name =
                  symbol_table.array_qubit_symbol_name(allocated_variable, idx);
              if (!symbol_table.has_symbol(original_qubit_var_name)) {
                // This original qubit has never been extracted...
                // Just create an extract and cache to the symbol table
                mlir::Value original_qubit_val =
                    builder.create<mlir::quantum::ExtractQubitOp>(
                        location, qubit_type, allocated_symbol, src_idx);
                symbol_table.add_symbol(original_qubit_var_name,
                                        original_qubit_val);
              }
              symbol_table.add_symbol_ref_alias(original_qubit_var_name,
                                                alias_qubit_var_name);
              ++counter;

              builder.create<mlir::quantum::AssignQubitOp>(
@@ -176,10 +194,36 @@ antlrcpp::Any qasm3_visitor::visitAliasStatement(
            };
            const auto new_size =
                slice_size_calc(orig_size, range_start, range_step, range_stop);
            // std::cout << "Adding symbol 2 " << in_aliasName << "\n";

            symbol_table.add_symbol(in_aliasName, array_slice,
                                    {std::to_string(new_size)});
            // std::cout << "Adding symbol 2 " << in_aliasName << "\n";
            for (int dest_idx = 0; dest_idx < new_size; ++dest_idx) {
              const int64_t range_start_pos =
                  range_start >= 0 ? range_start : orig_size + range_start;
              const int source_idx = range_start_pos + dest_idx * range_step;
              assert(source_idx >= 0);
              // Put the *alias* qubit (alias-name + index) into the symbol
              // table: mapped to the original qubit:
              const std::string alias_qubit_var_name =
                  symbol_table.array_qubit_symbol_name(in_aliasName, dest_idx);
              const std::string original_qubit_var_name =
                  symbol_table.array_qubit_symbol_name(allocated_variable,
                                                       source_idx);
              if (!symbol_table.has_symbol(original_qubit_var_name)) {
                // This original qubit has never been extracted...
                // Just create an extract and cache to the symbol table
                mlir::Value original_qubit_val =
                    builder.create<mlir::quantum::ExtractQubitOp>(
                        location, qubit_type, allocated_symbol,
                        get_or_create_constant_integer_value(
                            source_idx, location, builder.getI64Type(),
                            symbol_table, builder));
                symbol_table.add_symbol(original_qubit_var_name,
                                        original_qubit_val);
              }
              symbol_table.add_symbol_ref_alias(original_qubit_var_name,
                                                alias_qubit_var_name);
            }
          } else {
            printErrorMessage("Could not parse the alias statement.",
                              in_indexIdentifierContext);
@@ -231,13 +275,46 @@ antlrcpp::Any qasm3_visitor::visitAliasStatement(
              builder.create<mlir::quantum::ArrayConcatOp>(
                  location, array_type, first_reg_symbol, second_reg_symbol);
          const auto new_size = first_reg_size + second_reg_size;
          symbol_table.add_symbol(in_aliasName, array_concat,
                                  {std::to_string(new_size)});
          // std::cout << "Concatenate " << lhs_temp_var << "[" <<
          // first_reg_size
          //           << "] with " << rhs_temp_var << "[" << second_reg_size
          //           << "] -> " << in_aliasName << "[" << new_size << "].\n";

          symbol_table.add_symbol(in_aliasName, array_concat,
                                  {std::to_string(new_size)});
          // Add the qubit alias for the concatenated registers:
          for (int dest_idx = 0; dest_idx < new_size; ++dest_idx) {
            const std::string source_reg_name =
                dest_idx < first_reg_size ? lhs_temp_var : rhs_temp_var;
            const int source_idx = dest_idx < first_reg_size
                                       ? dest_idx
                                       : (dest_idx - first_reg_size);
            mlir::Value source_qreg_value = dest_idx < first_reg_size
                                                ? first_reg_symbol
                                                : second_reg_symbol;
            assert(source_idx >= 0);
            // Put the *alias* qubit (alias-name + index) into the symbol
            // table: mapped to the original qubit:
            const std::string alias_qubit_var_name =
                symbol_table.array_qubit_symbol_name(in_aliasName, dest_idx);
            const std::string original_qubit_var_name =
                symbol_table.array_qubit_symbol_name(source_reg_name,
                                                     source_idx);
            if (!symbol_table.has_symbol(original_qubit_var_name)) {
              // This original qubit has never been extracted...
              // Just create an extract and cache to the symbol table
              mlir::Value original_qubit_val =
                  builder.create<mlir::quantum::ExtractQubitOp>(
                      location, qubit_type, source_qreg_value,
                      get_or_create_constant_integer_value(
                          source_idx, location, builder.getI64Type(),
                          symbol_table, builder));
              symbol_table.add_symbol(original_qubit_var_name,
                                      original_qubit_val);
            }
            symbol_table.add_symbol_ref_alias(original_qubit_var_name,
                                              alias_qubit_var_name);
          }
        } else {
          printErrorMessage("Could not parse the alias statement.",
                            in_indexIdentifierContext);
Loading