Commit 6e5ad0ef authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixing tuple unpack lowering



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 4bc90777
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -129,7 +129,10 @@ int main(int argc, char **argv) {
  // Now lower MLIR to LLVM IR
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(*module, llvmContext);

  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }
  // Optimize the LLVM IR
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
+29 −27
Original line number Diff line number Diff line
@@ -30,52 +30,54 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite(
  assert(operands.size() == 1);
  ModuleOp parentModule = op->getParentOfType<ModuleOp>();
  auto context = parentModule->getContext();
  std::cout << "Before:\n";
  parentModule.dump();
  // Cast the tuple to a struct type:
  auto tuple_unpack_op = cast<mlir::quantum::TupleUnpackOp>(op);
  std::vector<Type> unpacked_type_list;
  std::vector<Type> tuple_struct_type_list;
  mlir::SmallVector<mlir::Type> tuple_struct_type_list;
  for (const auto &result : tuple_unpack_op.result()) {
    if (result.getType().isa<mlir::OpaqueType>() &&
        result.getType().cast<mlir::OpaqueType>().getTypeData() == "Array") {
      auto array_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
      unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(array_type));
      tuple_struct_type_list.emplace_back(array_type);
      tuple_struct_type_list.push_back(array_type);
    } else if (result.getType().isa<mlir::FloatType>()) {
      auto float_type = mlir::FloatType::getF64(context);
      unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(float_type));
      tuple_struct_type_list.emplace_back(float_type);
      tuple_struct_type_list.push_back(float_type);
    }
  }

  auto unpacked_type = LLVM::LLVMStructType::getLiteral(
      context, llvm::ArrayRef<Type>(tuple_struct_type_list));
  auto unpacked_struct_type =
      LLVM::LLVMStructType::getLiteral(context, tuple_struct_type_list);
  auto location = parentModule->getLoc();
  auto bitcast = rewriter.create<LLVM::BitcastOp>(
      location, LLVM::LLVMPointerType::get(unpacked_type), operands[0]);
  auto structPtr =
      rewriter
          .create<LLVM::BitcastOp>(
              location, LLVM::LLVMPointerType::get(unpacked_struct_type),
              operands[0])
          .res();

  std::vector<mlir::Value> unpacked_vals;
  for (size_t idx = 0; idx < unpacked_type_list.size(); ++idx) {
  mlir::SmallVector<mlir::Value> unpacked_vals;
  mlir::Value zero_cst = rewriter.create<LLVM::ConstantOp>(
      location, IntegerType::get(rewriter.getContext(), 64),
      rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
  for (size_t idx = 0; idx < tuple_struct_type_list.size(); ++idx) {
    mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>(
        location, IntegerType::get(rewriter.getContext(), 64),
        rewriter.getIntegerAttr(rewriter.getIndexType(), idx));
    auto getelementptr = rewriter.create<LLVM::GEPOp>(
        location, unpacked_type_list[idx], bitcast,
        idx_cst);
    auto field_ptr =
        rewriter
            .create<LLVM::GEPOp>(
                location,
                LLVM::LLVMPointerType::get(tuple_struct_type_list[idx]),
                structPtr, mlir::ArrayRef<mlir::Value>({zero_cst, idx_cst}))
            .res();
    auto load_op = rewriter.create<LLVM::LoadOp>(
        location, tuple_struct_type_list[idx], getelementptr.res());
    unpacked_vals.emplace_back(load_op.res());
        location, tuple_struct_type_list[idx], field_ptr);
    unpacked_vals.push_back(load_op.res());
  }

  for (size_t idx = 0; idx < unpacked_vals.size(); ++idx) {
    mlir::Value unpack_result = *(std::next(tuple_unpack_op.result_begin(), idx));
    unpack_result.replaceAllUsesWith(unpacked_vals[idx]);
  }
  rewriter.eraseOp(op);
  std::cout << "After:\n";
  parentModule.dump();
  rewriter.replaceOp(op, unpacked_vals);
  // std::cout << "After:\n";
  // parentModule.dump();
  return success();
}