Commit e8331f06 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Generate callable with captures for If body



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 7022ef25
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -229,10 +229,14 @@ def TupleUnpackOp : QuantumOp<"tupleUnpack", []> {
}

def CreateCallableOp : QuantumOp<"createCallable", []> {
    let arguments = (ins FlatSymbolRefAttr:$functors);
    let arguments = (ins FlatSymbolRefAttr:$functors, Variadic<AnyType>:$captures);
    let results = (outs CallableType:$callable);
    let printer = [{  auto op = *this;
  p << "q.createCallable" << "(" << op.functors() << ") : " << op.callable().getType(); }];
      p << "q.createCallable" << "(" << op.functors() << ") ";
      if (!op.captures().empty()) {
        p << "capture " << op.captures();
      }
    }];
}

def YieldOp : QuantumOp<"yield", [NoSideEffect, Terminator]> {
+57 −4
Original line number Diff line number Diff line
@@ -57,6 +57,57 @@ std::optional<BitComparisonExpression> tryParseSimpleBooleanExpression(

  return std::nullopt;
}

// Callable running-off captured vars...
mlir::Value create_capture_callable_gen(
    mlir::OpBuilder &builder, const std::string &func_name,
    mlir::ModuleOp &moduleOp, mlir::FuncOp &wrapped_func,
    std::vector<mlir::Value> &captured_vars) {
  auto context = builder.getContext();
  auto main_block = builder.saveInsertionPoint();
  mlir::Identifier dialect = mlir::Identifier::get("quantum", context);
  llvm::StringRef tuple_type_name("Tuple");
  auto tuple_type = mlir::OpaqueType::get(context, dialect, tuple_type_name);
  llvm::StringRef array_type_name("Array");
  auto array_type = mlir::OpaqueType::get(context, dialect, array_type_name);
  llvm::StringRef callable_type_name("Callable");
  auto callable_type =
      mlir::OpaqueType::get(context, dialect, callable_type_name);
  llvm::StringRef qubit_type_name("Qubit");
  auto qubit_type = mlir::OpaqueType::get(context, dialect, qubit_type_name);

  const std::vector<mlir::Type> argument_types{tuple_type, tuple_type,
                                               tuple_type};
  auto func_type = builder.getFunctionType(argument_types, llvm::None);
  const std::string BODY_WRAPPER_SUFFIX = "__body__wrapper";
  std::vector<mlir::FuncOp> all_wrapper_funcs;
  // Body wrapper:
  const std::string wrapper_fn_name = func_name + BODY_WRAPPER_SUFFIX;
  mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(),
                                                wrapper_fn_name, func_type));
  function_op.setVisibility(mlir::SymbolTable::Visibility::Private);
  auto &entryBlock = *function_op.addEntryBlock();
  builder.setInsertionPointToStart(&entryBlock);
  auto arguments = entryBlock.getArguments();
  assert(arguments.size() == 3);
  // Unpack from **captured** vars (not input args...)
  // i.e., Tuple # 0
  mlir::Value arg_tuple = arguments[0];
  auto fn_type = wrapped_func.getType().cast<mlir::FunctionType>();
  mlir::TypeRange arg_types(fn_type.getInputs());
  auto unpackOp = builder.create<mlir::quantum::TupleUnpackOp>(
      builder.getUnknownLoc(), arg_types, arg_tuple);
  auto call_op = builder.create<mlir::CallOp>(builder.getUnknownLoc(),
                                              wrapped_func, unpackOp.result());
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  moduleOp.push_back(function_op);
  builder.restoreInsertionPoint(main_block);
  auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>(
      builder.getUnknownLoc(), callable_type,
      builder.getSymbolRefAttr(wrapped_func),
      /*captures*/ llvm::makeArrayRef(captured_vars));
  return callable_create_op;
}
} // namespace

namespace qcor {
@@ -70,7 +121,6 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
  auto bit_check_conditional =
      tryParseSimpleBooleanExpression(*conditional_expr);
  // Currently, we're only support If (not else yet)

  if (bit_check_conditional.has_value() &&
      context->programBlock().size() == 1 &&
      symbol_table.try_lookup_meas_result(bit_check_conditional->var_name)
@@ -104,12 +154,13 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
      ss << (void *)antr_node;
      return ss.str();
    };
    const std::string tmp_fun_name =
    const std::string tmp_func_name =
        "if_body_" + toString(context->programBlock(0));
    auto func_type = builder.getFunctionType(argument_types, llvm::None);
    auto proto =
        mlir::FuncOp::create(builder.getUnknownLoc(), tmp_fun_name, func_type);
        mlir::FuncOp::create(builder.getUnknownLoc(), tmp_func_name, func_type);
    mlir::FuncOp function(proto);
    function.setVisibility(mlir::SymbolTable::Visibility::Private);
    auto &entryBlock = *function.addEntryBlock();
    builder.setInsertionPointToStart(&entryBlock);
    symbol_table.enter_new_scope();
@@ -121,10 +172,12 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    builder.restoreInsertionPoint(main_block);
    symbol_table.exit_scope();
    symbol_table.add_seen_function(tmp_fun_name, function);
    symbol_table.add_seen_function(tmp_func_name, function);
    symbol_table.set_last_created_block(nullptr);
    m_module.push_back(function);

    auto then_body_callable = create_capture_callable_gen(
        builder, tmp_func_name, m_module, function, argument_values);
    // Create a call to the function:
    // FIXME: this should be wrapped as a callable...
    auto then_body_builder = ifOp.getThenBodyBuilder();
+2 −1
Original line number Diff line number Diff line
@@ -181,7 +181,8 @@ void add_callable_gen(mlir::OpBuilder &builder, const std::string &func_name,
  builder.setInsertionPointToStart(&create_callable_entryBlock);
  auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>(
      builder.getUnknownLoc(), callable_type,
      builder.getSymbolRefAttr(wrapped_func));
      builder.getSymbolRefAttr(wrapped_func), /*captures*/
      llvm::makeArrayRef(std::vector<mlir::Value>{}));
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                 callable_create_op.callable());
  moduleOp.push_back(create_callable_function_op);