Commit dea33c80 authored by Tom Eccles's avatar Tom Eccles
Browse files

[mlir][Transforms] teach CSE about recursive memory effects

Add support for reasoning about operations with recursive memory effects
to CSE. The recursive effects are gathered by a helper function. I
decided to allow returning duplicates from the helper function because
there's no benefit to spending the computation time to remove them in
the existing use case.

Differential Revision: https://reviews.llvm.org/D156805
parent e6d5dcf8
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -332,6 +332,17 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);

/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
/// std::nullopt indicates that an option did not have a memory effect interface
/// and so no result could be obtained. An empty vector indicates that there
/// were no memory effects found (but every operation implemented the memory
/// effect interface or has RecursiveMemoryEffects). If the vector contains
/// multiple effects, these effects may be duplicates.
std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
getEffectsRecursively(Operation *rootOp);

/// Returns true if the given operation is speculatable, i.e. has no undefined
/// behavior or other side effects.
///
+33 −0
Original line number Diff line number Diff line
@@ -182,6 +182,39 @@ bool mlir::isMemoryEffectFree(Operation *op) {
  return true;
}

// the returned vector may contain duplicate effects
std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
mlir::getEffectsRecursively(Operation *rootOp) {
  SmallVector<MemoryEffects::EffectInstance> effects;
  SmallVector<Operation *> effectingOps(1, rootOp);
  while (!effectingOps.empty()) {
    Operation *op = effectingOps.pop_back_val();

    // If the operation has recursive effects, push all of the nested
    // operations on to the stack to consider.
    bool hasRecursiveEffects =
        op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
    if (hasRecursiveEffects) {
      for (Region &region : op->getRegions()) {
        for (Block &block : region) {
          for (Operation &nestedOp : block) {
            effectingOps.push_back(&nestedOp);
          }
        }
      }
    }

    if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
      effectInterface.getEffects(effects);
    } else if (!hasRecursiveEffects) {
      // the operation does not have recursive memory effects or implement
      // the memory effect op interface. Its effects are unknown.
      return std::nullopt;
    }
  }
  return effects;
}

bool mlir::isSpeculatable(Operation *op) {
  auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
  if (!conditionallySpeculatable)
+13 −7
Original line number Diff line number Diff line
@@ -199,17 +199,23 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
    }
  }
  while (nextOp && nextOp != toOp) {
    auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
    std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
        getEffectsRecursively(nextOp);
    if (!effects) {
      // TODO: Do we need to handle other effects generically?
      // If the operation does not implement the MemoryEffectOpInterface we
    // conservatively assumes it writes.
    if ((nextOpMemEffects &&
         nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
        !nextOpMemEffects) {
      // conservatively assume it writes.
      result.first->second =
          std::make_pair(nextOp, MemoryEffects::Write::get());
      return true;
    }

    for (const MemoryEffects::EffectInstance &effect : *effects) {
      if (isa<MemoryEffects::Write>(effect.getEffect())) {
        result.first->second = {nextOp, MemoryEffects::Write::get()};
        return true;
      }
    }
    nextOp = nextOp->getNextNode();
  }
  result.first->second = std::make_pair(toOp, nullptr);
+2 −3
Original line number Diff line number Diff line
@@ -332,8 +332,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK:               scf.yield %[[VAL_145]]
// CHECK:             }
// CHECK:             %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]]
// CHECK:             %[[VAL_148:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
// CHECK:             %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_148]]
// CHECK:             %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]]
// CHECK:             %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]]
// CHECK:             %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]]
// CHECK:               %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]]
+1 −3
Original line number Diff line number Diff line
@@ -142,9 +142,7 @@
// CHECK:                   scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
// CHECK:                 }
// CHECK:                 %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
// CHECK:                 %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
// CHECK:                 %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
// CHECK:                 memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
// CHECK:                 memref.store %[[VAL_112]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
// CHECK:                 scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
// CHECK:               }
// CHECK:               %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
Loading