Commit 433520cf authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Construct functor table array for Callable creation



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 5628ba25
Loading
Loading
Loading
Loading
+42 −13
Original line number Diff line number Diff line
@@ -103,21 +103,51 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
      false);
  FlatSymbolRefAttr symbol_ref =
      SymbolRefAttr::get(create_callable_op.functors(), context);
  // mlir::Value func_ptr = rewriter.create<LLVM::AddressOfOp>(
  //     location, LLVM::LLVMPointerType::get(callable_entry_ftype), symbol_ref)
  auto callable_entry_fn_array_type = LLVM::LLVMArrayType::get(
      LLVM::LLVMPointerType::get(callable_entry_ftype), 4);
  auto callback_fn_array_type = LLVM::LLVMArrayType::get(
      LLVM::LLVMPointerType::get(capture_callback_ftype), 2);
  auto save_pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointToStart(&parentModule.getRegion().getBlocks().front());
  const std::string functor_array_name =
      create_callable_op.functors().str() + "__Qops";
  auto fPtr_array = ArrayAttr::get({symbol_ref, symbol_ref, symbol_ref, symbol_ref} , context);
  auto fPtr_array_const_global = rewriter.create<LLVM::GlobalOp>(
      location, callable_entry_fn_array_type, /*isConstant=*/true, LLVM::Linkage::Internal,
      functor_array_name.c_str(), fPtr_array);
  rewriter.restoreInsertionPoint(save_pt);
  mlir::Value value_1_const = rewriter.create<LLVM::ConstantOp>(
      location, IntegerType::get(rewriter.getContext(), 64),
      rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
  mlir::Value callable_entry_fn_array = rewriter.create<LLVM::AllocaOp>(
      location, LLVM::LLVMPointerType::get(callable_entry_fn_array_type),
      value_1_const,
      /*alignment=*/0);

  const std::vector<mlir::Value> functor_ptr_values{
      // Base
      rewriter.create<LLVM::AddressOfOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype),
          symbol_ref),
      // Adjoint
      rewriter.create<LLVM::NullOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype)),
      // Controlled
      rewriter.create<LLVM::NullOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype)),
      // Controlled Adjoint
      rewriter.create<LLVM::NullOp>(
          location, LLVM::LLVMPointerType::get(callable_entry_ftype)),
  };

  mlir::Value zero_index = rewriter.create<LLVM::ConstantOp>(
      location, IntegerType::get(rewriter.getContext(), 64),
      rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
  for (size_t func_idx = 0; func_idx < functor_ptr_values.size(); ++func_idx) {
    mlir::Value func_index_val = rewriter.create<LLVM::ConstantOp>(
        location, IntegerType::get(rewriter.getContext(), 64),
        rewriter.getIntegerAttr(rewriter.getIndexType(), func_idx));
    mlir::Value func_in_array_ptr = rewriter.create<LLVM::GEPOp>(
        location,
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(callable_entry_ftype)),
        callable_entry_fn_array, ArrayRef<Value>({zero_index, func_index_val}));
    rewriter.create<LLVM::StoreOp>(location, functor_ptr_values[func_idx],
                                   func_in_array_ptr);
  }

  auto callable_return_type =
      LLVM::LLVMPointerType::get(get_quantum_type("Callable", context));
@@ -144,8 +174,7 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
    qir_symbol_ref = mlir::SymbolRefAttr::get(qir_create_callable, context);
  }
  
  mlir::Value callable_entry_nullPtr = rewriter.create<LLVM::NullOp>(
      location, LLVM::LLVMPointerType::get(callable_entry_fn_array_type));
  // Callbacks and captured tuple ==> null
  mlir::Value callbacks_nullPtr = rewriter.create<LLVM::NullOp>(
      location, LLVM::LLVMPointerType::get(callback_fn_array_type));
  mlir::Value tuple_nullPtr =
@@ -153,7 +182,7 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite(
  auto createCallableCallOp = rewriter.create<mlir::CallOp>(
      location, qir_symbol_ref, callable_return_type,
      ArrayRef<Value>(
          {callable_entry_nullPtr, callbacks_nullPtr, tuple_nullPtr}));
          {callable_entry_fn_array, callbacks_nullPtr, tuple_nullPtr}));
  rewriter.replaceOp(op, createCallableCallOp.getResult(0));
  return success();
}