Commit 7022ef25 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Wrap if body into a function in prep. for Callable conversion



Call-site context is captured implicitly via the symbol table.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent bb833c1f
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -168,4 +168,18 @@ ScopedSymbolTable::try_lookup_meas_result(const std::string &bit_var_name) {
  }
  return try_lookup_meas_result(get_symbol(bit_var_name));
}

std::unordered_map<std::string, mlir::Value>
ScopedSymbolTable::get_all_visible_symbols() {
  std::unordered_map<std::string, mlir::Value> all_symbols;
  for (int i = current_scope; i >= 0; i--) {
    for (auto &[k, v] : scoped_symbol_tables[i]) {
      // Don't override if seeing duplicated variables.
      if (all_symbols.find(k) == all_symbols.end()) {
        all_symbols[k] = v;
      }
    }
  }
  return all_symbols;
}
}  // namespace qcor
 No newline at end of file
+4 −0
Original line number Diff line number Diff line
@@ -179,6 +179,10 @@ public:
    }
  }

  // Get all visible symbols at the current scope.
  // Nearer symbols take precedence over further ones (if having the same name)
  std::unordered_map<std::string, mlir::Value> get_all_visible_symbols();

  // Create new scope symbol table
  // will push_back on scoped_symbol_tables;
  void enter_new_scope() {
+50 −5
Original line number Diff line number Diff line
@@ -81,11 +81,56 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    auto ifOp =
        builder.create<mlir::quantum::ConditionalOp>(location, meas_var.value(),
                                                     /*withElseRegion=*/false);
    auto body_builder = ifOp.getThenBodyBuilder();
    auto cached_builder = builder;
    builder = body_builder;
    
    // Strategy: we wrap the body as a Callable capturing
    // all avaiable variables at the current scope.
    // Note: we could detect which variables are used in the 
    // conditional block body to be included in the capture.
    auto all_vars = symbol_table.get_all_visible_symbols();
    auto main_block = builder.saveInsertionPoint();
    std::vector<mlir::Type> argument_types;
    std::vector<std::string> argument_names;
    std::vector<mlir::Value> argument_values;

    for (auto &[k, v] : all_vars) {
      argument_names.emplace_back(k);
      argument_values.emplace_back(v);
      argument_types.emplace_back(v.getType());
    }

    // Use the ANTLR node ptr (hex) as id for this temp. function
    const auto toString = [](auto *antr_node) {
      std::stringstream ss;
      ss << (void *)antr_node;
      return ss.str();
    };
    const std::string tmp_fun_name =
        "if_body_" + toString(context->programBlock(0));
    auto func_type = builder.getFunctionType(argument_types, llvm::None);
    auto proto =
        mlir::FuncOp::create(builder.getUnknownLoc(), tmp_fun_name, func_type);
    mlir::FuncOp function(proto);
    auto &entryBlock = *function.addEntryBlock();
    builder.setInsertionPointToStart(&entryBlock);
    symbol_table.enter_new_scope();
    auto arguments = entryBlock.getArguments();
    for (int i = 0; i < arguments.size(); i++) {
      symbol_table.add_symbol(argument_names[i], arguments[i], {}, true);
    }
    visitChildren(context->programBlock(0));
    builder = cached_builder;
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    builder.restoreInsertionPoint(main_block);
    symbol_table.exit_scope();
    symbol_table.add_seen_function(tmp_fun_name, function);
    symbol_table.set_last_created_block(nullptr);
    m_module.push_back(function);

    // Create a call to the function:
    // FIXME: this should be wrapped as a callable...
    auto then_body_builder = ifOp.getThenBodyBuilder();
    auto body_call_op = then_body_builder.create<mlir::CallOp>(
        then_body_builder.getUnknownLoc(), function,
        llvm::makeArrayRef(argument_values));

    // Done
    return 0;