Commit ad18d18f authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Make the pow->loop->unroll optimization to work



Seems like the i64->index casting makes the loop unroll not working.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 67b9ca85
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -104,7 +104,10 @@ def PowURegion : QuantumOp<"pow_u_region", [
          repeated a set number of times.}];
  // Rationale: we wrap modifier block as a proper value-semantics op.
  // i.e., forwarding SSA vars at input and output.
  let arguments = (ins AnyI64:$pow, Variadic<QubitType>:$qubits);
  // Note: we use pow of Index type to be compatible with loop bound types.
  // i.e., if optimization enabled, any constant `pow` values can be propagated
  // to SCF/Affine loops w/o the need for casting (prevent loop unrolling)
  let arguments = (ins Index:$pow, Variadic<QubitType>:$qubits);
  let results = (outs Variadic<QubitType>:$result);
  let regions = (region SizedRegion<1>:$body);
  let skipDefaultBuilders = 1;
+6 −1
Original line number Diff line number Diff line
@@ -399,10 +399,15 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
                                                        builder.getI64Type());
          }
        }
        if (!power.getType().isIndex()) {
          power = builder.create<mlir::IndexCastOp>(
              location, builder.getIndexType(), power);
        }
      } else {
        auto p = symbol_table.evaluate_constant_integer_expression(
            m->expression()->getText());
        auto pow_attr = mlir::IntegerAttr::get(builder.getI64Type(), p);
        assert(p >= 0);
        auto pow_attr = mlir::IntegerAttr::get(builder.getIndexType(), p);
        power = builder.create<mlir::ConstantOp>(location, pow_attr);
      }
      auto powerUOp = builder.create<mlir::quantum::PowURegion>(location, power,
+10 −18
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -27,22 +28,18 @@ void ModifierBlockInlinerPass::handlePowU() {
    // Must be a single-block op
    assert(op.body().getBlocks().size() == 1);
    mlir::OpBuilder rewriter(op);
    mlir::Value powVal = [&]() -> mlir::Value {
      if (op.pow().getType().isIndex()) {
        return op.pow();
      }
      return rewriter.create<mlir::IndexCastOp>(
          op.getLoc(), rewriter.getIndexType(), op.pow());
    }();
    assert(op.pow().getType().isIndex());
    mlir::Value powVal = op.pow();
    mlir::Value lbs_val = rewriter.create<mlir::ConstantOp>(
        op.getLoc(), mlir::IntegerAttr::get(rewriter.getIndexType(), 0));
    mlir::Value step_val = rewriter.create<mlir::ConstantOp>(
        op.getLoc(), mlir::IntegerAttr::get(rewriter.getIndexType(), 1));
    mlir::Block &powBlock = op.body().getBlocks().front();
    // Convert the pow modifier to a For loop,
    // which might be unrolled if possible (constant-value loop bound)
    auto forOp = rewriter.create<mlir::scf::ForOp>(
        op.getLoc(), lbs_val, powVal, step_val, op.qubits(),
    mlir::ValueRange lbs(lbs_val);
    mlir::ValueRange ubs(powVal);
    auto forOp = rewriter.create<mlir::AffineForOp>(
        op.getLoc(), lbs, rewriter.getMultiDimIdentityMap(lbs.size()), ubs,
        rewriter.getMultiDimIdentityMap(ubs.size()), 1, llvm::None,
        [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc,
            mlir::Value iv, mlir::ValueRange itrArgs) {
          mlir::OpBuilder::InsertionGuard guard(nestedBuilder);
@@ -52,18 +49,13 @@ void ModifierBlockInlinerPass::handlePowU() {
            if (auto terminator =
                    mlir::dyn_cast_or_null<mlir::quantum::ModifierEndOp>(
                        newOp)) {
              nestedBuilder.create<mlir::scf::YieldOp>(nestedLoc,
                                                       terminator.qubits());
              nestedBuilder.create<mlir::AffineYieldOp>(nestedLoc);
              newOp->erase();
              break;
            }
          }
        });

    assert(forOp.results().size() == op.result().size());
    for (size_t i = 0; i < op.result().size(); ++i) {
      op.result()[i].replaceAllUsesWith(forOp.results()[i]);
    }
    op.body().getBlocks().clear();
    deadOps.emplace_back(op.getOperation());
    // ModuleOp parentModule = op->getParentOfType<ModuleOp>();