Commit d6e85ae9 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Work on lowering captured vars to tuple (packing vars to a QIR tuple)



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent fff5c81c
Loading
Loading
Loading
Loading
+28 −4
Original line number Diff line number Diff line
@@ -95,6 +95,21 @@ mlir::Value create_capture_callable_gen(
                                              wrapped_func, unpackOp.result());
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  moduleOp.push_back(function_op);
  
  // !! We only ever invoke the body functor, create dummy functors for adj/ctrl
  for (const auto &suffix :
       {"__adj__wrapper", "__ctl__wrapper", "__ctladj__wrapper"}) {
    builder.restoreInsertionPoint(main_block);
    const std::string temp_fn_name = func_name + suffix;
    mlir::FuncOp fn_op(
        mlir::FuncOp::create(builder.getUnknownLoc(), temp_fn_name, func_type));
    fn_op.setVisibility(mlir::SymbolTable::Visibility::Private);
    auto &entryBlock = *fn_op.addEntryBlock();
    builder.setInsertionPointToStart(&entryBlock);
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    moduleOp.push_back(fn_op);
  }

  builder.restoreInsertionPoint(main_block);
  auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>(
      builder.getUnknownLoc(), callable_type,
@@ -132,12 +147,17 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    std::vector<mlir::Type> argument_types;
    std::vector<std::string> argument_names;
    std::vector<mlir::Value> argument_values;

    // Narrow the list of supported types for tuple unpack...
    // We don't support all types atm.
    for (auto &[k, v] : all_vars) {
      // QIR types and Float (rotation angles)
      if (v.getType().isa<mlir::OpaqueType>() ||
          v.getType().isa<mlir::FloatType>()) {
        argument_names.emplace_back(k);
        argument_values.emplace_back(v);
        argument_types.emplace_back(v.getType());
      }
    }

    // Use the ANTLR node ptr (hex) as id for this temp. function
    const auto toString = [](auto *antr_node) {
@@ -165,6 +185,10 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    symbol_table.exit_scope();
    symbol_table.add_seen_function(tmp_func_name, function);
    symbol_table.set_last_created_block(nullptr);
    for (int i = 0; i < arguments.size(); ++i) {
      symbol_table.replace_symbol(symbol_table.get_symbol(argument_names[i]),
                                  argument_values[i]);
    }
    m_module.push_back(function);

    auto then_body_callable = create_capture_callable_gen(
+114 −11
Original line number Diff line number Diff line
@@ -52,6 +52,9 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite(
      tuple_struct_type_list.push_back(mlir::FloatType::getF64(context));
    } else if (result.getType().isa<mlir::IntegerType>()) {
      tuple_struct_type_list.push_back(mlir::IntegerType::get(context, 64));
    } else {
      std::cout << "WE DON'T SUPPORT TUPLE UNPACK FOR THE TYPE\n";
      exit(0);
    }
  }

@@ -126,12 +129,12 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
      value_1_const,
      /*alignment=*/0);

   
  const std::string kernel_name = create_callable_op.functors().str();
  const std::string BODY_WRAPPER_NAME = kernel_name + "__body__wrapper";
  const std::string ADJOINT_WRAPPER_NAME = kernel_name + "__adj__wrapper";
  const std::string CTRL_WRAPPER_NAME = kernel_name + "__ctl__wrapper";
  const std::string CTRL_ADJOINT_WRAPPER_NAME = kernel_name + "__ctladj__wrapper";
  const std::string CTRL_ADJOINT_WRAPPER_NAME =
      kernel_name + "__ctladj__wrapper";

  const std::vector<mlir::Value> functor_ptr_values{
      // Base
@@ -180,9 +183,9 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
    //                           TuplePtr capture)
    auto create_callable_ftype = LLVM::LLVMFunctionType::get(
        callable_return_type,
        llvm::ArrayRef<Type>{LLVM::LLVMPointerType::get(callable_entry_fn_array_type),
                             LLVM::LLVMPointerType::get(callback_fn_array_type),
                             tuple_type},
        llvm::ArrayRef<Type>{
            LLVM::LLVMPointerType::get(callable_entry_fn_array_type),
            LLVM::LLVMPointerType::get(callback_fn_array_type), tuple_type},
        false);

    // Insert the function declaration
@@ -196,12 +199,112 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
  // Callbacks and captured tuple ==> null
  mlir::Value callbacks_nullPtr = rewriter.create<LLVM::NullOp>(
      location, LLVM::LLVMPointerType::get(callback_fn_array_type));
  mlir::Value tuple_nullPtr =
      rewriter.create<LLVM::NullOp>(location, tuple_type);
  mlir::Value capture_tuple_ptr = [&]() {
    if (create_callable_op.captures().empty()) {
      auto op = rewriter.create<LLVM::NullOp>(location, tuple_type);
      return op.res();
    } else {
      mlir::SmallVector<mlir::Type> tuple_struct_type_list;
      size_t tuple_size_in_bytes = 0;
      for (const auto &captured_var : create_callable_op.captures()) {
        if (captured_var.getType().isa<mlir::OpaqueType>() &&
            captured_var.getType().cast<mlir::OpaqueType>().getTypeData() ==
                "Array") {
          tuple_struct_type_list.push_back(
              LLVM::LLVMPointerType::get(get_quantum_type("Array", context)));
          tuple_size_in_bytes += sizeof(void *);
        } else if (captured_var.getType().isa<mlir::OpaqueType>() &&
                   captured_var.getType()
                           .cast<mlir::OpaqueType>()
                           .getTypeData() == "Qubit") {
          tuple_struct_type_list.push_back(
              LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)));
          tuple_size_in_bytes += sizeof(void *);
        } else if (captured_var.getType().isa<mlir::OpaqueType>() &&
                   captured_var.getType()
                           .cast<mlir::OpaqueType>()
                           .getTypeData() == "Tuple") {
          tuple_struct_type_list.push_back(
              LLVM::LLVMPointerType::get(get_quantum_type("Tuple", context)));
          tuple_size_in_bytes += sizeof(void *);
        } else if (captured_var.getType().isa<mlir::FloatType>()) {
          tuple_struct_type_list.push_back(mlir::FloatType::getF64(context));
          tuple_size_in_bytes += sizeof(double);
        } else if (captured_var.getType().isa<mlir::IntegerType>()) {
          tuple_struct_type_list.push_back(mlir::IntegerType::get(context, 64));
          tuple_size_in_bytes += sizeof(int64_t);
        } else {
          std::cout << "WE DON'T SUPPORT TUPLE PACK FOR THE TYPE\n";
          exit(0);
        }
      }
      mlir::Value tuple_size_value = rewriter.create<LLVM::ConstantOp>(
          location, mlir::IntegerType::get(rewriter.getContext(), 64),
          rewriter.getIntegerAttr(
              mlir::IntegerType::get(rewriter.getContext(), 64),
              tuple_size_in_bytes));

      // Tuple create signature: TuplePtr __quantum__rt__tuple_create(int64_t
      // size)
      FlatSymbolRefAttr tuple_create_symbol_ref = [&]() {
        const std::string qir_tuple_create_fn_name =
            "__quantum__rt__tuple_create";
        if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(
                qir_tuple_create_fn_name)) {
          return SymbolRefAttr::get(qir_tuple_create_fn_name, context);
        } else {
          auto ftype = LLVM::LLVMFunctionType::get(
              tuple_type,
              llvm::ArrayRef<Type>{mlir::IntegerType::get(context, 64)}, false);

          // Insert the function declaration
          PatternRewriter::InsertionGuard insertGuard(rewriter);
          rewriter.setInsertionPointToStart(parentModule.getBody());
          rewriter.create<LLVM::LLVMFuncOp>(parentModule->getLoc(),
                                            qir_tuple_create_fn_name, ftype);
          return mlir::SymbolRefAttr::get(qir_tuple_create_fn_name, context);
        }
      }();
      auto createTupleCallOp = rewriter.create<mlir::CallOp>(
          location, tuple_create_symbol_ref, tuple_type,
          ArrayRef<Value>({tuple_size_value}));
      mlir::Value tuplePtr = createTupleCallOp.getResult(0);

      // Store to tuple:
      auto tuple_struct_type =
          LLVM::LLVMStructType::getLiteral(context, tuple_struct_type_list);
      auto structPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  location, LLVM::LLVMPointerType::get(tuple_struct_type),
                  tuplePtr)
              .res();

      mlir::Value zero_cst = rewriter.create<LLVM::ConstantOp>(
          location, IntegerType::get(rewriter.getContext(), 32),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
      for (size_t idx = 0; idx < tuple_struct_type_list.size(); ++idx) {
        mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>(
            location, IntegerType::get(rewriter.getContext(), 32),
            rewriter.getIntegerAttr(rewriter.getIndexType(), idx));
        auto field_ptr =
            rewriter
                .create<LLVM::GEPOp>(
                    location,
                    LLVM::LLVMPointerType::get(tuple_struct_type_list[idx]),
                    structPtr, ArrayRef<Value>({zero_cst, idx_cst}))
                .res();
        auto store_op = rewriter.create<LLVM::StoreOp>(
            location, create_callable_op.captures()[idx], field_ptr);
      }

      return tuplePtr;
    }
  }();
  auto createCallableCallOp = rewriter.create<mlir::CallOp>(
      location, qir_symbol_ref, callable_return_type,
      ArrayRef<Value>(
          {callable_entry_fn_array, callbacks_nullPtr, tuple_nullPtr}));
          {callable_entry_fn_array, callbacks_nullPtr, capture_tuple_ptr}));
  rewriter.replaceOp(op, createCallableCallOp.getResult(0));
  return success();
}