Commit 7fb09965 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

More proper support for SSA value semantics in qreg alias handler

Related to https://github.com/ORNL-QCI/qcor/issues/186



Enhance the SymbolTable type to track symbol aliasing in a 'by-ref' manner.

We want to bind qubit var name (qreg_name + index) to the true SSA value to be tracked by quantum instruction handler.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 3b69a4ff
Loading
Loading
Loading
Loading
+109 −6
Original line number Diff line number Diff line
@@ -11,7 +11,102 @@
#include "mlir/IR/BuiltinTypes.h"

namespace qcor {
using SymbolTable = std::map<std::string, mlir::Value>;
// using SymbolTable = std::map<std::string, mlir::Value>;
struct SymbolTable {
  std::map<std::string, mlir::Value>::iterator begin() {
    return var_name_to_value.begin();
  }
  std::map<std::string, mlir::Value>::iterator end() {
    return var_name_to_value.end();
  }

  // Check if we have this symbol:
  // If this is a *root* (master) symbol, backed by a mlir::Value
  // or this is an alias (by reference), normally only for Qubits (SSA values)
  bool has_symbol(const std::string &var_name) {
    if (var_name_to_value.find(var_name) != var_name_to_value.end()) {
      return true;
    }
    const auto alias_name_check_iter =
        ref_var_name_to_orig_var_name.find(var_name);
    if (alias_name_check_iter != ref_var_name_to_orig_var_name.end()) {
      const std::string &original_var_name = alias_name_check_iter->second;
      return var_name_to_value.find(original_var_name) !=
             var_name_to_value.end();
    }
    return false;
  }

  // Add a reference alias, i.e. the two variable names are bound
  // to a single mlir::Value.
  // Note: chaining of aliasing is traced to the root var name:
  // e.g. we can support a, b (refers to a), then c refers to b.
  void add_alias(const std::string &orig_var_name,
                 const std::string &alias_var_name) {
    if (ref_var_name_to_orig_var_name.find(orig_var_name) !=
        ref_var_name_to_orig_var_name.end()) {
      // The original var name is an alias itself...
      const std::string &root_var_name =
          ref_var_name_to_orig_var_name[orig_var_name];
      ref_var_name_to_orig_var_name[alias_var_name] = root_var_name;
    } else {
      assert(var_name_to_value.find(orig_var_name) != var_name_to_value.end());
      ref_var_name_to_orig_var_name[alias_var_name] = orig_var_name;
    }
  }

  // Get the symbol (mlir::Value) taking into account potential alias chaining.
  mlir::Value get_symbol(const std::string &var_name) {
    auto iter = var_name_to_value.find(var_name);
    if (iter != var_name_to_value.end()) {
      return iter->second;
    }

    auto alias_iter = ref_var_name_to_orig_var_name.find(var_name);
    if (alias_iter != ref_var_name_to_orig_var_name.end()) {
      const std::string &root_var_name = alias_iter->second;
      assert(var_name_to_value.find(root_var_name) != var_name_to_value.end());
      return var_name_to_value[root_var_name];
    }
    printErrorMessage("Unknown symbol '" + var_name + "'.");
    return mlir::Value();
  }

  void add_or_update_symbol(const std::string &var_name, mlir::Value value) {
    var_name_to_value[var_name] = value;
  }

  // Compatible w/ a raw map (assuming the variable is original/root)
  mlir::Value &operator[](const std::string &var_name) {
    return var_name_to_value[var_name];
  }

  mlir::Value &at(const std::string &var_name) {
    return var_name_to_value.at(var_name);
  }

  void insert(const std::pair<std::string, mlir::Value> &new_var) {
    var_name_to_value.insert(new_var);
  }

  std::map<std::string, mlir::Value>::iterator
  find(const std::string &var_name) {
    return var_name_to_value.find(var_name);
  }

  std::map<std::string, mlir::Value>::size_type
  count(const std::string &var_name) const {
    return var_name_to_value.count(var_name);
  }

private:
  std::map<std::string, mlir::Value> var_name_to_value;
  // By reference var name aliasing map:
  // track a variable name representing references to the original mlir::Value,
  // e.g. qubit aliasing from slicing.
  std::unordered_map<std::string, std::string> ref_var_name_to_orig_var_name;
};

using ConstantIntegerTable =
    std::map<std::pair<std::uint64_t, int>, mlir::Value>;

@@ -187,8 +282,7 @@ class ScopedSymbolTable {

  bool has_symbol(const std::string variable_name, const std::size_t scope) {
    for (int i = scope; i >= 0; i--) { // nasty bug, auto instead of int...
      if (!scoped_symbol_tables[i].empty() &&
          scoped_symbol_tables[i].count(variable_name)) {
      if (scoped_symbol_tables[i].has_symbol(variable_name)) {
        return true;
      }
    }
@@ -196,6 +290,15 @@ class ScopedSymbolTable {
    return false;
  }

  void add_symbol_ref_alias(const std::string &orig_variable_name,
                            const std::string &alias_ref_variable_name) {
    // Sanity check for debug
    assert(has_symbol(orig_variable_name));
    assert(!has_symbol(alias_ref_variable_name));
    scoped_symbol_tables[current_scope].add_alias(orig_variable_name,
                                                  alias_ref_variable_name);
  }

  SymbolTable& get_global_symbol_table() { return scoped_symbol_tables[0]; }

  template <typename OpTy>
@@ -227,8 +330,8 @@ class ScopedSymbolTable {
  mlir::Value get_symbol(const std::string variable_name,
                         const std::size_t scope) {
    for (auto i = scope; i >= 0; i--) {
      if (scoped_symbol_tables[i].count(variable_name)) {
        return scoped_symbol_tables[i][variable_name];
      if (scoped_symbol_tables[i].has_symbol(variable_name)) {
        return scoped_symbol_tables[i].get_symbol(variable_name);
      }
    }

+18 −0
Original line number Diff line number Diff line
@@ -89,6 +89,24 @@ antlrcpp::Any qasm3_visitor::visitAliasStatement(
                  builder);
              auto src_idx = get_or_create_constant_integer_value(
                  idx, location, builder.getI64Type(), symbol_table, builder);

              // Put the *alias* qubit (alias-name + index) into the symbol
              // table: mapped to the original qubit:
              const std::string alias_qubit_var_name =
                  in_aliasName + std::to_string(counter);
              const std::string original_qubit_var_name =
                  allocated_variable + std::to_string(idx);
              if (!symbol_table.has_symbol(original_qubit_var_name)) {
                // This original qubit has never been extracted...
                // Just create an extract and cache to the symbol table
                mlir::Value original_qubit_val =
                    builder.create<mlir::quantum::ExtractQubitOp>(
                        location, qubit_type, allocated_symbol, src_idx);
                symbol_table.add_symbol(original_qubit_var_name,
                                        original_qubit_val);
              }
              symbol_table.add_symbol_ref_alias(original_qubit_var_name,
                                                alias_qubit_var_name);
              ++counter;

              builder.create<mlir::quantum::AssignQubitOp>(