Commit 1bc16c06 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Properly handle inlining of modifier regions into the main block



Simple cloning is not a proper way to do this, causing dangling references...

Also, we need to do this lowering as a two-step procedure to prevent race condition in the dialect conversion pipeline.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 62a6f744
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -476,6 +476,19 @@ for val in [0:8] {
  // EXPECT_FALSE(qcor::execute(check_ccx, "ccx"));
}

TEST(qasm3VisitorTester, checkNestedModifier) {
  const std::string check_nested = R"#(OPENQASM 3;
qubit q;
qubit qq;
inv @ pow(2) @ t q;
pow(2) @ inv @ t q;
ctrl @ pow(2) @ t q, qq;
pow(2) @ ctrl @ t q, qq;
)#";
  auto llvm = qcor::mlir_compile(check_nested, "nested",
                                 qcor::OutputType::LLVMIR, false, 0);
  std::cout << "LLVM:\n" << llvm << "\n";
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
+3 −0
Original line number Diff line number Diff line
@@ -61,6 +61,7 @@ const std::string mlir_compile(const std::string &src,
  if (opt_level > 0) {
    qcor::configureOptimizationPasses(pm);
  }
  pm.addPass(std::make_unique<qcor::ModifierRegionRewritePass>());
  pm.addPass(std::make_unique<qcor::QuantumToLLVMLoweringPass>(
      true, unique_function_names));
  auto module_op = (*module).getOperation();
@@ -116,6 +117,7 @@ int execute(const std::string &src, const std::string &kernel_name,
  // Create the PassManager for lowering to LLVM MLIR and run it
  mlir::PassManager pm(&context);
  qcor::configureOptimizationPasses(pm);
  pm.addPass(std::make_unique<qcor::ModifierRegionRewritePass>());
  pm.addPass(std::make_unique<qcor::QuantumToLLVMLoweringPass>(
      true, unique_function_names));
  auto module_op = (*module).getOperation();
@@ -179,6 +181,7 @@ int execute(const std::string &src, const std::string &kernel_name,
  // Create the PassManager for lowering to LLVM MLIR and run it
  mlir::PassManager pm(&context);
  qcor::configureOptimizationPasses(pm);
  pm.addPass(std::make_unique<qcor::ModifierRegionRewritePass>());
  pm.addPass(std::make_unique<qcor::QuantumToLLVMLoweringPass>(
      true, unique_function_names));
  auto module_op = (*module).getOperation();
+1 −0
Original line number Diff line number Diff line
@@ -173,6 +173,7 @@ int main(int argc, char **argv) {
  }

  // Lower MLIR to LLVM
  pm.addPass(std::make_unique<qcor::ModifierRegionRewritePass>());
  pm.addPass(std::make_unique<qcor::QuantumToLLVMLoweringPass>(
      qoptimizations, unique_function_names));

+86 −109
Original line number Diff line number Diff line
@@ -24,6 +24,35 @@
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

namespace {
// Inline a region into a location specified by an Op
void inlineRegion(mlir::Region *regionToInline,
                  mlir::Operation *inlineLocation) {
  mlir::Block *insertBlock = inlineLocation->getBlock();
  assert(insertBlock);
  mlir::Region *insertRegion = insertBlock->getParent();
  // Split the insertion block.
  mlir::Block *postInsertBlock =
      insertBlock->splitBlock(inlineLocation->getIterator());
  mlir::BlockAndValueMapping mapper;
  regionToInline->cloneInto(insertRegion, postInsertBlock->getIterator(),
                            mapper);
  auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
                                    postInsertBlock->getIterator());
  mlir::Block *firstNewBlock = &*newBlocks.begin();

  auto *firstBlockTerminator = firstNewBlock->getTerminator();
  firstBlockTerminator->erase();
  // Merge the post insert block into the cloned entry block.
  firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                        postInsertBlock->getOperations());
  postInsertBlock->erase();
  // Splice the instructions of the inlined entry block into the insert block.
  insertBlock->getOperations().splice(insertBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
}
} // namespace
namespace qcor {
// Note: the Modifier regions implement quantum dataflow analysis (value
// semantics) by returning new mlir::Value of all qubit operands (to be
@@ -41,42 +70,6 @@ LogicalResult PowURegionOpLowering::matchAndRewrite(
  ModuleOp parentModule = op->getParentOfType<ModuleOp>();
  auto context = parentModule->getContext();
  auto location = parentModule->getLoc();
  // Start
  {
    FlatSymbolRefAttr qir_get_fn_ptr = [&]() {
      static const std::string qir_start_func =
          "__quantum__rt__start_pow_u_region";
      if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_start_func)) {
        return SymbolRefAttr::get(qir_start_func, context);
      } else {

        // prototype should be (int64_t) -> void :

        auto void_type = LLVM::LLVMVoidType::get(context);

        auto func_type = LLVM::LLVMFunctionType::get(
            void_type, llvm::ArrayRef<Type>{}, false);

        PatternRewriter::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(parentModule.getBody());
        rewriter.create<LLVM::LLVMFuncOp>(location, qir_start_func, func_type);

        return mlir::SymbolRefAttr::get(qir_start_func, context);
      }
    }();

    rewriter.create<mlir::CallOp>(
        location, qir_get_fn_ptr, LLVM::LLVMVoidType::get(context),
        llvm::makeArrayRef(std::vector<mlir::Value>{}));
  }

  // Inline the body into the current scope:
  auto casted = cast<mlir::quantum::PowURegion>(op);
  mlir::Block &powBlock = casted.body().getBlocks().front();
  for (auto &subOp : powBlock.getOperations()) {
    rewriter.insert(subOp.clone());
  }

  // End
  {
    FlatSymbolRefAttr qir_get_fn_ptr = [&]() {
@@ -107,6 +100,7 @@ LogicalResult PowURegionOpLowering::matchAndRewrite(
  }

  {
    auto casted = cast<mlir::quantum::PowURegion>(op);
    mlir::SmallVector<mlir::Value> chained_values;
    for (const auto &targetQubit : casted.qubits()) {
      chained_values.push_back(targetQubit);
@@ -124,43 +118,6 @@ LogicalResult CtrlURegionOpLowering::matchAndRewrite(
  ModuleOp parentModule = op->getParentOfType<ModuleOp>();
  auto context = parentModule->getContext();
  auto location = parentModule->getLoc();

  // Start
  {
    FlatSymbolRefAttr qir_get_fn_ptr = [&]() {
      static const std::string qir_start_func =
          "__quantum__rt__start_ctrl_u_region";
      if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_start_func)) {
        return SymbolRefAttr::get(qir_start_func, context);
      } else {
        // prototype should be () -> void :
        // ret is void
        auto void_type = LLVM::LLVMVoidType::get(context);

        auto func_type = LLVM::LLVMFunctionType::get(
            void_type, llvm::ArrayRef<Type>{}, false);

        PatternRewriter::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(parentModule.getBody());
        rewriter.create<LLVM::LLVMFuncOp>(location, qir_start_func, func_type);

        return mlir::SymbolRefAttr::get(qir_start_func, context);
      }
    }();

    rewriter.create<mlir::CallOp>(
        location, qir_get_fn_ptr, LLVM::LLVMVoidType::get(context),
        llvm::makeArrayRef(std::vector<mlir::Value>{}));
  }

  // Inline
  auto casted = cast<mlir::quantum::CtrlURegion>(op);
  {
    mlir::Block &ctrlBlock = casted.body().getBlocks().front();
    for (auto &subOp : ctrlBlock.getOperations()) {
      rewriter.insert(subOp.clone());
    }
  }
  // End
  {
    FlatSymbolRefAttr qir_get_fn_ptr = [&]() {
@@ -199,6 +156,7 @@ LogicalResult CtrlURegionOpLowering::matchAndRewrite(
  }

  {
    auto casted = cast<mlir::quantum::CtrlURegion>(op);
    mlir::SmallVector<mlir::Value> chained_values{casted.ctrl_qubit()};
    for (const auto &targetQubit : casted.qubits()) {
      chained_values.push_back(targetQubit);
@@ -215,42 +173,6 @@ LogicalResult AdjURegionOpLowering::matchAndRewrite(
  ModuleOp parentModule = op->getParentOfType<ModuleOp>();
  auto context = parentModule->getContext();
  auto location = parentModule->getLoc();

  // Start
  {
    FlatSymbolRefAttr qir_get_fn_ptr = [&]() {
      static const std::string qir_start_func =
          "__quantum__rt__start_adj_u_region";
      if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_start_func)) {
        return SymbolRefAttr::get(qir_start_func, context);
      } else {
        // prototype should be () -> void :
        auto void_type = LLVM::LLVMVoidType::get(context);

        auto func_type = LLVM::LLVMFunctionType::get(
            void_type, llvm::ArrayRef<Type>{}, false);

        PatternRewriter::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(parentModule.getBody());
        rewriter.create<LLVM::LLVMFuncOp>(location, qir_start_func, func_type);

        return mlir::SymbolRefAttr::get(qir_start_func, context);
      }
    }();

    rewriter.create<mlir::CallOp>(
        location, qir_get_fn_ptr, LLVM::LLVMVoidType::get(context),
        llvm::makeArrayRef(std::vector<mlir::Value>{}));
  }

  // Inline
  {
    auto casted = cast<mlir::quantum::AdjURegion>(op);
    mlir::Block &adjBlock = casted.body().getBlocks().front();
    for (auto &subOp : adjBlock.getOperations()) {
      rewriter.insert(subOp.clone());
    }
  }
  // End
  {
    FlatSymbolRefAttr qir_get_fn_ptr = [&]() {
@@ -297,4 +219,59 @@ LogicalResult EndModifierRegionOpLowering::matchAndRewrite(
  rewriter.eraseOp(op);
  return success();
}

void ModifierRegionRewritePass::getDependentDialects(
    DialectRegistry &registry) const {
  registry.insert<LLVM::LLVMDialect>();
}

void ModifierRegionRewritePass::runOnOperation() {
  const auto insertStartCall = [](const std::string &qir_start_func,
                                  mlir::OpBuilder &opBuilder,
                                  mlir::ModuleOp &parentModule) {
    mlir::FlatSymbolRefAttr startModifiedU = [&]() {
      PatternRewriter::InsertionGuard insertGuard(opBuilder);
      opBuilder.setInsertionPointToStart(
          &parentModule.getRegion().getBlocks().front());
      if (parentModule.lookupSymbol<mlir::FuncOp>(qir_start_func)) {
        auto fnNameAttr = opBuilder.getSymbolRefAttr(qir_start_func);
        return fnNameAttr;
      }

      auto func_decl = opBuilder.create<mlir::FuncOp>(
          opBuilder.getUnknownLoc(), qir_start_func,
          opBuilder.getFunctionType(llvm::None, llvm::None));
      func_decl.setVisibility(mlir::SymbolTable::Visibility::Private);
      return mlir::SymbolRefAttr::get(qir_start_func,
                                      parentModule->getContext());
    }();

    opBuilder.create<mlir::CallOp>(opBuilder.getUnknownLoc(), startModifiedU,
                                   llvm::None, llvm::None);
  };

  getOperation().walk([&](mlir::quantum::AdjURegion op) {
    mlir::OpBuilder rewriter(op);
    mlir::ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    insertStartCall("__quantum__rt__start_adj_u_region", rewriter,
                    parentModule);
    inlineRegion(&op.body(), op.getOperation());
  });

  getOperation().walk([&](mlir::quantum::CtrlURegion op) {
    mlir::OpBuilder rewriter(op);
    mlir::ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    insertStartCall("__quantum__rt__start_ctrl_u_region", rewriter,
                    parentModule);
    inlineRegion(&op.body(), op.getOperation());
  });

  getOperation().walk([&](mlir::quantum::PowURegion op) {
    mlir::OpBuilder rewriter(op);
    mlir::ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    insertStartCall("__quantum__rt__start_pow_u_region", rewriter,
                    parentModule);
    inlineRegion(&op.body(), op.getOperation());
  });
}
} // namespace qcor
 No newline at end of file
+7 −0
Original line number Diff line number Diff line
@@ -45,4 +45,11 @@ public:
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

struct ModifierRegionRewritePass
    : public PassWrapper<ModifierRegionRewritePass, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry &registry) const override;
  void runOnOperation() final;
  ModifierRegionRewritePass() {}
};
} // namespace qcor
 No newline at end of file
Loading