Commit 5628ba25 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Work on Callable gen



- Create callable needs to be wrapped in a function as well since we cannot do some ops at global scope

- Lower to __quantum__rt__callable_create

Still need to figure out the constant array of function pointers in LLVM...

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent e6aed746
Loading
Loading
Loading
Loading
+12 −2
Original line number Diff line number Diff line
@@ -34,12 +34,22 @@ void add_body_wrapper(mlir::OpBuilder &builder, const std::string &func_name,
      builder.getUnknownLoc(), arg_types, arg_tuple);
  auto call_op = builder.create<mlir::CallOp>(builder.getUnknownLoc(),
                                              wrapped_func, unpackOp.result());
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  moduleOp.push_back(function_op);
  builder.setInsertionPointToStart(&moduleOp.getRegion().getBlocks().front());
  
  // Add a function to create the callable wrapper for this kernel
  auto create_callable_func_type = builder.getFunctionType({}, callable_type);
  const std::string create_callable_fn_name = func_name + "__callable";
  auto create_callable_func_proto =
      mlir::FuncOp::create(builder.getUnknownLoc(), create_callable_fn_name, create_callable_func_type);
  mlir::FuncOp create_callable_function_op(create_callable_func_proto);
  auto &create_callable_entryBlock = *create_callable_function_op.addEntryBlock();
  builder.setInsertionPointToStart(&create_callable_entryBlock);
  auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>(
      builder.getUnknownLoc(), callable_type,
      builder.getSymbolRefAttr(function_op));
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), callable_create_op.callable());
  moduleOp.push_back(create_callable_function_op);
  builder.restoreInsertionPoint(main_block);
}
}; // namespace
+71 −0
Original line number Diff line number Diff line
@@ -84,6 +84,77 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite(
LogicalResult CreateCallableOpLowering::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  ModuleOp parentModule = op->getParentOfType<ModuleOp>();
  auto location = parentModule->getLoc();
  auto context = parentModule->getContext();
  auto create_callable_op = cast<mlir::quantum::CreateCallableOp>(op);  
  // Signature: void (%Tuple*, %Tuple*, %Tuple*)
  auto tuple_type =
      LLVM::LLVMPointerType::get(get_quantum_type("Tuple", context));
  // typedef void (*CallableEntryType)(TuplePtr, TuplePtr, TuplePtr);
  // typedef void (*CaptureCallbackType)(TuplePtr, int32_t);
  auto callable_entry_ftype = LLVM::LLVMFunctionType::get(
      LLVM::LLVMVoidType::get(context),
      llvm::ArrayRef<Type>{tuple_type, tuple_type, tuple_type}, false);
  auto capture_callback_ftype = LLVM::LLVMFunctionType::get(
      LLVM::LLVMVoidType::get(context),
      llvm::ArrayRef<Type>{tuple_type,
                           IntegerType::get(rewriter.getContext(), 32)},
      false);
  FlatSymbolRefAttr symbol_ref =
      SymbolRefAttr::get(create_callable_op.functors(), context);
  // mlir::Value func_ptr = rewriter.create<LLVM::AddressOfOp>(
  //     location, LLVM::LLVMPointerType::get(callable_entry_ftype), symbol_ref)
  auto callable_entry_fn_array_type = LLVM::LLVMArrayType::get(
      LLVM::LLVMPointerType::get(callable_entry_ftype), 4);
  auto callback_fn_array_type = LLVM::LLVMArrayType::get(
      LLVM::LLVMPointerType::get(capture_callback_ftype), 2);
  auto save_pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointToStart(&parentModule.getRegion().getBlocks().front());
  const std::string functor_array_name =
      create_callable_op.functors().str() + "__Qops";
  auto fPtr_array = ArrayAttr::get({symbol_ref, symbol_ref, symbol_ref, symbol_ref} , context);
  auto fPtr_array_const_global = rewriter.create<LLVM::GlobalOp>(
      location, callable_entry_fn_array_type, /*isConstant=*/true, LLVM::Linkage::Internal,
      functor_array_name.c_str(), fPtr_array);
  rewriter.restoreInsertionPoint(save_pt);
  
  auto callable_return_type =
      LLVM::LLVMPointerType::get(get_quantum_type("Callable", context));
  FlatSymbolRefAttr qir_symbol_ref;
  if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_create_callable)) {
    qir_symbol_ref = SymbolRefAttr::get(qir_create_callable, context);
  } else {
    // Callable *
    // __quantum__rt__callable_create(Callable::CallableEntryType *ft,
    //                           Callable::CaptureCallbackType *callbacks,
    //                           TuplePtr capture)
    auto create_callable_ftype = LLVM::LLVMFunctionType::get(
        callable_return_type,
        llvm::ArrayRef<Type>{LLVM::LLVMPointerType::get(callable_entry_fn_array_type),
                             LLVM::LLVMPointerType::get(callback_fn_array_type),
                             tuple_type},
        false);

    // Insert the function declaration
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(parentModule.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(
        parentModule->getLoc(), qir_create_callable, create_callable_ftype);
    qir_symbol_ref = mlir::SymbolRefAttr::get(qir_create_callable, context);
  }
  
  mlir::Value callable_entry_nullPtr = rewriter.create<LLVM::NullOp>(
      location, LLVM::LLVMPointerType::get(callable_entry_fn_array_type));
  mlir::Value callbacks_nullPtr = rewriter.create<LLVM::NullOp>(
      location, LLVM::LLVMPointerType::get(callback_fn_array_type));
  mlir::Value tuple_nullPtr =
      rewriter.create<LLVM::NullOp>(location, tuple_type);
  auto createCallableCallOp = rewriter.create<mlir::CallOp>(
      location, qir_symbol_ref, callable_return_type,
      ArrayRef<Value>(
          {callable_entry_nullPtr, callbacks_nullPtr, tuple_nullPtr}));
  rewriter.replaceOp(op, createCallableCallOp.getResult(0));
  return success();
}
} // namespace qcor
 No newline at end of file
+2 −0
Original line number Diff line number Diff line
@@ -17,6 +17,8 @@ public:
class CreateCallableOpLowering : public ConversionPattern {
protected:
public:
  inline static const std::string qir_create_callable =
      "__quantum__rt__callable_create";
  explicit CreateCallableOpLowering(MLIRContext *context)
      : ConversionPattern(mlir::quantum::CreateCallableOp::getOperationName(),
                          1, context) {}