Commit 13ebdef0 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Implemented all Callable functor table wrappers for QASM3 sub-routines



Using Adj/Ctrl region annotations accordingly.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent f27a0eeb
Loading
Loading
Loading
Loading
+139 −37
Original line number Diff line number Diff line
#include "qasm3_visitor.hpp"

namespace {
void add_body_wrapper(mlir::OpBuilder &builder, const std::string &func_name,
// Helper to generate a QIR callable wrapper for a QASM3 subroutine:
// A Callable is constructed from a functor table (array of size 4)
// for body (base), adjoint, controlled, and controlled adjoint functors
// that all have signature of void(Tuple, Tuple, Tuple).
// This method generates those 4 wrappers as well as the function to construct
// the Callable.
void add_callable_gen(mlir::OpBuilder &builder, const std::string &func_name,
                      mlir::ModuleOp &moduleOp, mlir::FuncOp &wrapped_func) {
  // define internal void @body__wrapper(%Tuple* %capture-tuple, %Tuple*
  // %arg-tuple, %Tuple* %result-tuple)
  const std::string wrapper_fn_name = func_name + "__body__wrapper";
  auto main_block = builder.saveInsertionPoint();
  auto context = builder.getContext();
  llvm::StringRef tuple_type_name("Tuple");
  auto main_block = builder.saveInsertionPoint();
  mlir::Identifier dialect = mlir::Identifier::get("quantum", context);
  llvm::StringRef tuple_type_name("Tuple");
  auto tuple_type = mlir::OpaqueType::get(context, dialect, tuple_type_name);
  llvm::StringRef array_type_name("Array");
  auto array_type = mlir::OpaqueType::get(context, dialect, array_type_name);
  llvm::StringRef callable_type_name("Callable");
  auto callable_type =
      mlir::OpaqueType::get(context, dialect, callable_type_name);
  llvm::StringRef qubit_type_name("Qubit");
  auto qubit_type = mlir::OpaqueType::get(context, dialect, qubit_type_name);

  const std::vector<mlir::Type> argument_types{tuple_type, tuple_type,
                                               tuple_type};
  auto func_type = builder.getFunctionType(argument_types, llvm::None);
  auto proto =
      mlir::FuncOp::create(builder.getUnknownLoc(), wrapper_fn_name, func_type);
  mlir::FuncOp function_op(proto);

  const std::string BODY_WRAPPER_SUFFIX = "__body__wrapper";
  const std::string ADJOINT_WRAPPER_SUFFIX = "__adj__wrapper";
  const std::string CTRL_WRAPPER_SUFFIX = "__ctl__wrapper";
  const std::string CTRL_ADJOINT_WRAPPER_SUFFIX = "__ctladj__wrapper";

  std::vector<mlir::FuncOp> all_wrapper_funcs;
  {
    // Body wrapper:
    const std::string wrapper_fn_name = func_name + BODY_WRAPPER_SUFFIX;
    mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(),
                                                  wrapper_fn_name, func_type));
    function_op.setVisibility(mlir::SymbolTable::Visibility::Private);
    auto &entryBlock = *function_op.addEntryBlock();

    builder.setInsertionPointToStart(&entryBlock);
    auto arguments = entryBlock.getArguments();
    assert(arguments.size() == 3);
@@ -32,23 +46,111 @@ void add_body_wrapper(mlir::OpBuilder &builder, const std::string &func_name,
    mlir::TypeRange arg_types(fn_type.getInputs());
    auto unpackOp = builder.create<mlir::quantum::TupleUnpackOp>(
        builder.getUnknownLoc(), arg_types, arg_tuple);
  auto call_op = builder.create<mlir::CallOp>(builder.getUnknownLoc(),
                                              wrapped_func, unpackOp.result());
    auto call_op = builder.create<mlir::CallOp>(
        builder.getUnknownLoc(), wrapped_func, unpackOp.result());
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    moduleOp.push_back(function_op);
    all_wrapper_funcs.emplace_back(function_op);
  }

  {
    // Adjoint wrapper:
    const std::string wrapper_fn_name = func_name + ADJOINT_WRAPPER_SUFFIX;
    mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(),
                                                  wrapper_fn_name, func_type));
    function_op.setVisibility(mlir::SymbolTable::Visibility::Private);
    auto &entryBlock = *function_op.addEntryBlock();
    builder.setInsertionPointToStart(&entryBlock);
    // Wrap the call to the body wrapperin StartAdjointURegion and
    // EndAdjointURegion
    builder.create<mlir::quantum::StartAdjointURegion>(builder.getUnknownLoc());
    mlir::FuncOp body_wrapper = all_wrapper_funcs[0];
    // Forward tuple arguments to the body (will unpack there)
    auto call_op = builder.create<mlir::CallOp>(
        builder.getUnknownLoc(), body_wrapper, entryBlock.getArguments());
    builder.create<mlir::quantum::EndAdjointURegion>(builder.getUnknownLoc());
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    moduleOp.push_back(function_op);
    all_wrapper_funcs.emplace_back(function_op);
  }
  {
    // Controlled wrapper:
    const std::string wrapper_fn_name = func_name + CTRL_WRAPPER_SUFFIX;
    mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(),
                                                  wrapper_fn_name, func_type));
    function_op.setVisibility(mlir::SymbolTable::Visibility::Private);
    auto &entryBlock = *function_op.addEntryBlock();
    builder.setInsertionPointToStart(&entryBlock);
    auto arguments = entryBlock.getArguments();
    assert(arguments.size() == 3);
    mlir::Value arg_tuple = arguments[1];
    auto fn_type = wrapped_func.getType().cast<mlir::FunctionType>();
    // Unpack to Array + Tuple (Array = controlled bits)
    // { Array + { Body Tuple } }
    // FIXME: currently, we can only handle single-qubit control
    // TODO: update EndCtrlURegion to take an array of qubits.
    mlir::TypeRange arg_types({array_type, tuple_type});
    auto unpackOp = builder.create<mlir::quantum::TupleUnpackOp>(
        builder.getUnknownLoc(), arg_types, arg_tuple);
    mlir::FuncOp body_wrapper = all_wrapper_funcs[0];
    mlir::Value control_array = unpackOp.result()[0];
    mlir::Value body_arg_tuple = unpackOp.result()[1];

    // Extract the control qubit:
    mlir::Value qubit_idx = builder.create<mlir::ConstantOp>(
        builder.getUnknownLoc(),
        mlir::IntegerAttr::get(builder.getI64Type(), 0));
    mlir::Value ctrl_qubit = builder.create<mlir::quantum::ExtractQubitOp>(
        builder.getUnknownLoc(), qubit_type, control_array, qubit_idx);

    // Call the body wrapped in StartCtrlURegion/EndCtrlURegion
    builder.create<mlir::quantum::StartCtrlURegion>(builder.getUnknownLoc());
    auto call_op = builder.create<mlir::CallOp>(
        builder.getUnknownLoc(), body_wrapper,
        llvm::ArrayRef<mlir::Value>(
            {arguments[0], body_arg_tuple, arguments[2]}));
    builder.create<mlir::quantum::EndCtrlURegion>(builder.getUnknownLoc(),
                                                  ctrl_qubit);
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    moduleOp.push_back(function_op);
    all_wrapper_funcs.emplace_back(function_op);
  }
  {
    // Controlled Adjoint wrapper:
    const std::string wrapper_fn_name = func_name + CTRL_ADJOINT_WRAPPER_SUFFIX;
    mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(),
                                                  wrapper_fn_name, func_type));
    function_op.setVisibility(mlir::SymbolTable::Visibility::Private);
    auto &entryBlock = *function_op.addEntryBlock();
    builder.setInsertionPointToStart(&entryBlock);
    // Wrap the call to the ctrl wrapper wrapped in StartAdjointURegion and
    // EndAdjointURegion
    builder.create<mlir::quantum::StartAdjointURegion>(builder.getUnknownLoc());
    mlir::FuncOp ctrl_wrapper = all_wrapper_funcs[2];
    // Forward tuple arguments to the controlled (will unpack there)
    auto call_op = builder.create<mlir::CallOp>(
        builder.getUnknownLoc(), ctrl_wrapper, entryBlock.getArguments());
    builder.create<mlir::quantum::EndAdjointURegion>(builder.getUnknownLoc());
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    moduleOp.push_back(function_op);
    all_wrapper_funcs.emplace_back(function_op);
  }

  // Add a function to create the callable wrapper for this kernel
  auto create_callable_func_type = builder.getFunctionType({}, callable_type);
  const std::string create_callable_fn_name = func_name + "__callable";
  auto create_callable_func_proto =
      mlir::FuncOp::create(builder.getUnknownLoc(), create_callable_fn_name, create_callable_func_type);
      mlir::FuncOp::create(builder.getUnknownLoc(), create_callable_fn_name,
                           create_callable_func_type);
  mlir::FuncOp create_callable_function_op(create_callable_func_proto);
  auto &create_callable_entryBlock = *create_callable_function_op.addEntryBlock();
  auto &create_callable_entryBlock =
      *create_callable_function_op.addEntryBlock();
  builder.setInsertionPointToStart(&create_callable_entryBlock);
  auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>(
      builder.getUnknownLoc(), callable_type,
      builder.getSymbolRefAttr(function_op));
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), callable_create_op.callable());
      builder.getSymbolRefAttr(wrapped_func));
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                 callable_create_op.callable());
  moduleOp.push_back(create_callable_function_op);
  builder.restoreInsertionPoint(main_block);
}
@@ -257,7 +359,7 @@ antlrcpp::Any qasm3_visitor::visitSubroutineDefinition(
  m_module.push_back(interop);

  // TODO: add a compile switch to enable/disable this export:
  add_body_wrapper(builder, subroutine_name, m_module, function);
  add_callable_gen(builder, subroutine_name, m_module, function);
  return 0;
}

+23 −9
Original line number Diff line number Diff line
@@ -43,6 +43,11 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite(
                   "Qubit") {
      tuple_struct_type_list.push_back(
          LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)));
    } else if (result.getType().isa<mlir::OpaqueType>() &&
               result.getType().cast<mlir::OpaqueType>().getTypeData() ==
                   "Tuple") {
      tuple_struct_type_list.push_back(
          LLVM::LLVMPointerType::get(get_quantum_type("Tuple", context)));
    } else if (result.getType().isa<mlir::FloatType>()) {
      tuple_struct_type_list.push_back(mlir::FloatType::getF64(context));
    } else if (result.getType().isa<mlir::IntegerType>()) {
@@ -106,8 +111,7 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
      llvm::ArrayRef<Type>{tuple_type,
                           IntegerType::get(rewriter.getContext(), 32)},
      false);
  FlatSymbolRefAttr symbol_ref =
      SymbolRefAttr::get(create_callable_op.functors(), context);
  
  auto callable_entry_fn_array_type = LLVM::LLVMArrayType::get(
      LLVM::LLVMPointerType::get(callable_entry_ftype), 4);
  auto callback_fn_array_type = LLVM::LLVMArrayType::get(
@@ -122,20 +126,30 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
      value_1_const,
      /*alignment=*/0);

   
  const std::string kernel_name = create_callable_op.functors().str();
  const std::string BODY_WRAPPER_NAME = kernel_name + "__body__wrapper";
  const std::string ADJOINT_WRAPPER_NAME = kernel_name + "__adj__wrapper";
  const std::string CTRL_WRAPPER_NAME = kernel_name + "__ctl__wrapper";
  const std::string CTRL_ADJOINT_WRAPPER_NAME = kernel_name + "__ctladj__wrapper";

  const std::vector<mlir::Value> functor_ptr_values{
      // Base
      rewriter.create<LLVM::AddressOfOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype),
          symbol_ref),
          SymbolRefAttr::get(BODY_WRAPPER_NAME.c_str(), context)),
      // Adjoint
      rewriter.create<LLVM::NullOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype)),
      rewriter.create<LLVM::AddressOfOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype),
          SymbolRefAttr::get(ADJOINT_WRAPPER_NAME.c_str(), context)),
      // Controlled
      rewriter.create<LLVM::NullOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype)),
      rewriter.create<LLVM::AddressOfOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype),
          SymbolRefAttr::get(CTRL_WRAPPER_NAME.c_str(), context)),
      // Controlled Adjoint
      rewriter.create<LLVM::NullOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype)),
      rewriter.create<LLVM::AddressOfOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype),
          SymbolRefAttr::get(CTRL_ADJOINT_WRAPPER_NAME.c_str(), context)),
  };

  mlir::Value zero_index = rewriter.create<LLVM::ConstantOp>(