Commit 5abf128d authored by Benjamin Kramer's avatar Benjamin Kramer
Browse files

Add a pass that specializes parallel loops for easier unrolling and vectorization

This matches loops with a affine.min upper bound, limiting the trip
count to a constant, and rewrites them into two loops, one with constant
upper bound and one with variable upper bound. The assumption is that
the constant upper bound loop will be unrolled and vectorized, which is
preferable if this is the hot path.

Differential Revision: https://reviews.llvm.org/D75240
parent 586f13ae
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -23,6 +23,10 @@ class Pass;
/// Creates a loop fusion pass which fuses parallel loops.
std::unique_ptr<Pass> createParallelLoopFusionPass();

/// Creates a pass that specializes parallel loop for unrolling and
/// vectorization.
std::unique_ptr<Pass> createParallelLoopSpecializationPass();

/// Creates a pass which tiles innermost parallel loops.
std::unique_ptr<Pass>
createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
+1 −0
Original line number Diff line number Diff line
@@ -109,6 +109,7 @@ inline void registerAllPasses() {

  // LoopOps
  createParallelLoopFusionPass();
  createParallelLoopSpecializationPass();
  createParallelLoopTilingPass();

  // QuantOps
+1 −0
Original line number Diff line number Diff line
add_llvm_library(MLIRLoopOpsTransforms
  ParallelLoopFusion.cpp
  ParallelLoopSpecialization.cpp
  ParallelLoopTiling.cpp

  ADDITIONAL_HEADER_DIRS
+76 −0
Original line number Diff line number Diff line
//===- ParallelLoopSpecialization.cpp - loop.parallel specializeation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Specializes parallel loops for easier unrolling and vectorization.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/LoopOps/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using loop::ParallelOp;

/// Rewrite a loop with bounds defined by an affine.min with a constant into 2
/// loops after checking if the bounds are equal to that constant. This is
/// beneficial if the loop will almost always have the constant bound and that
/// version can be fully unrolled and vectorized.
static void specializeLoopForUnrolling(ParallelOp op) {
  SmallVector<int64_t, 2> constantIndices;
  constantIndices.reserve(op.upperBound().size());
  for (auto bound : op.upperBound()) {
    auto minOp = dyn_cast_or_null<AffineMinOp>(bound.getDefiningOp());
    if (!minOp)
      return;
    int64_t minConstant = std::numeric_limits<int64_t>::max();
    for (auto expr : minOp.map().getResults()) {
      if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
        minConstant = std::min(minConstant, constantIndex.getValue());
    }
    if (minConstant == std::numeric_limits<int64_t>::max())
      return;
    constantIndices.push_back(minConstant);
  }

  OpBuilder b(op);
  BlockAndValueMapping map;
  Value cond;
  for (auto bound : llvm::zip(op.upperBound(), constantIndices)) {
    Value constant = b.create<ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
    Value cmp = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq,
                                 std::get<0>(bound), constant);
    cond = cond ? b.create<AndOp>(op.getLoc(), cond, cmp) : cmp;
    map.map(std::get<0>(bound), constant);
  }
  auto ifOp = b.create<loop::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}

namespace {
struct ParallelLoopSpecialization
    : public FunctionPass<ParallelLoopSpecialization> {
  void runOnFunction() override {
    getFunction().walk([](ParallelOp op) { specializeLoopForUnrolling(op); });
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
  return std::make_unique<ParallelLoopSpecialization>();
}

static PassRegistration<ParallelLoopSpecialization>
    pass("parallel-loop-specialization",
         "Specialize parallel loops for vectorization.");
+46 −0
Original line number Diff line number Diff line
// RUN: mlir-opt %s -parallel-loop-specialization -split-input-file | FileCheck %s --dump-input-on-failure

#map0 = affine_map<()[s0, s1] -> (1024, s0 - s1)>
#map1 = affine_map<()[s0, s1] -> (64, s0 - s1)>

func @parallel_loop(%outer_i0: index, %outer_i1: index, %A: memref<?x?xf32>, %B: memref<?x?xf32>,
                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = dim %A, 0 : memref<?x?xf32>
  %d1 = dim %A, 1 : memref<?x?xf32>
  %b0 = affine.min #map0()[%d0, %outer_i0]
  %b1 = affine.min #map1()[%d1, %outer_i1]
  loop.parallel (%i0, %i1) = (%c0, %c0) to (%b0, %b1) step (%c1, %c1) {
    %B_elem = load %B[%i0, %i1] : memref<?x?xf32>
    %C_elem = load %C[%i0, %i1] : memref<?x?xf32>
    %sum_elem = addf %B_elem, %C_elem : f32
    store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
  }
  return
}

// CHECK-LABEL:   func @parallel_loop(
// CHECK-SAME:                        [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: memref<?x?xf32>, [[VAL_3:%.*]]: memref<?x?xf32>, [[VAL_4:%.*]]: memref<?x?xf32>, [[VAL_5:%.*]]: memref<?x?xf32>) {
// CHECK:           [[VAL_6:%.*]] = constant 0 : index
// CHECK:           [[VAL_7:%.*]] = constant 1 : index
// CHECK:           [[VAL_8:%.*]] = dim [[VAL_2]], 0 : memref<?x?xf32>
// CHECK:           [[VAL_9:%.*]] = dim [[VAL_2]], 1 : memref<?x?xf32>
// CHECK:           [[VAL_10:%.*]] = affine.min #map0(){{\[}}[[VAL_8]], [[VAL_0]]]
// CHECK:           [[VAL_11:%.*]] = affine.min #map1(){{\[}}[[VAL_9]], [[VAL_1]]]
// CHECK:           [[VAL_12:%.*]] = constant 1024 : index
// CHECK:           [[VAL_13:%.*]] = cmpi "eq", [[VAL_10]], [[VAL_12]] : index
// CHECK:           [[VAL_14:%.*]] = constant 64 : index
// CHECK:           [[VAL_15:%.*]] = cmpi "eq", [[VAL_11]], [[VAL_14]] : index
// CHECK:           [[VAL_16:%.*]] = and [[VAL_13]], [[VAL_15]] : i1
// CHECK:           loop.if [[VAL_16]] {
// CHECK:             loop.parallel ([[VAL_17:%.*]], [[VAL_18:%.*]]) = ([[VAL_6]], [[VAL_6]]) to ([[VAL_12]], [[VAL_14]]) step ([[VAL_7]], [[VAL_7]]) {
// CHECK:               store
// CHECK:             }
// CHECK:           } else {
// CHECK:             loop.parallel ([[VAL_22:%.*]], [[VAL_23:%.*]]) = ([[VAL_6]], [[VAL_6]]) to ([[VAL_10]], [[VAL_11]]) step ([[VAL_7]], [[VAL_7]]) {
// CHECK:               store
// CHECK:             }
// CHECK:           }
// CHECK:           return
// CHECK:         }