Commit ffdbeccc authored by Matthias Springer's avatar Matthias Springer
Browse files

[mlir][bufferization] Add bufferization.alloc_tensor op

This change adds a new op `alloc_tensor` to the bufferization dialect. During bufferization, this op is always lowered to a buffer allocation (unless it is "eliminated" by a pre-processing pass). It is useful to have such an op in tensor land, because it allows users to model tensor SSA use-def chains (which drive bufferization decisions) and because tensor SSA use-def chains can be analyzed by One-Shot Bufferize, while memref values cannot.

This change also replaces all uses of linalg.init_tensor in bufferization-related code with bufferization.alloc_tensor.

linalg.init_tensor and bufferization.alloc_tensor are similar, but the purpose of the former one is just to carry a shape. It does not indicate a memory allocation.

linalg.init_tensor is not suitable for modelling SSA use-def chains for bufferization purposes, because linalg.init_tensor is marked as not having side effects (in contrast to alloc_tensor). As such, it is legal to move linalg.init_tensor ops around/CSE them/etc. This is not desirable for alloc_tensor; it represents an explicit buffer allocation while still in tensor land and such allocations should not suddenly disappear or get moved around when running the canonicalizer/CSE/etc.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC

Differential Revision: https://reviews.llvm.org/D126003
parent 4f6ac969
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

//===----------------------------------------------------------------------===//
// Bufferization Dialect
+3 −1
Original line number Diff line number Diff line
@@ -25,7 +25,9 @@ def Bufferization_Dialect : Dialect {
    found in [bufferization](/docs/Bufferization/) and [buffer
    deallocation](/docs/BufferDeallocationInternals/).
  }];
  let dependentDialects = ["memref::MemRefDialect", "tensor::TensorDialect"];
  let dependentDialects = [
    "AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
  ];

  let extraClassDeclaration = [{
    /// An attribute that can override writability of buffers of tensor function
+116 −0
Original line number Diff line number Diff line
@@ -12,12 +12,128 @@
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"

class Bufferization_Op<string mnemonic, list<Trait> traits = []>
    : Op<Bufferization_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
    [BufferizableOpInterface,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
  let summary = "buffer allocation in tensor land";

  let description = [{
    `bufferization.alloc_tensor` is an operation that bufferizes to a buffer
    allocation of a given shape. The shape could be dynamic or static.
    Reading from the result of an `alloc_tensor` op yields an undefined value.

    `alloc_tensor` is a helper op for bufferization. It marks the beginning of
    a new tensor SSA use-def chain and is used to control in-place bufferization
    decisions during One-Shot Bufferize.
  }];

  let arguments =
    (ins Variadic<Index>:$sizes, I64ArrayAttr:$static_sizes);

  let results = (outs AnyTensor:$result);

  let assemblyFormat = [{
    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
    `:` type($result)
  }];

  let extraClassDeclaration = [{
    LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);

    bool isMemoryWrite(OpResult opResult, const AnalysisState &state) const {
      // AllocTensorOps allocate but do not write.
      return false;
    }

    static StringRef getStaticSizesAttrName() {
      return "static_sizes";
    }

    RankedTensorType getType() {
      return getResult().getType().cast<RankedTensorType>();
    }

    // Infer the shape of the result tensor given the static shapes
    // and element type of the result tensor.
    static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
                                Attribute encoding = {});

    // Return true if the size of the tensor is dynamic at `idx`
    bool isDynamicSize(unsigned idx) {
      APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
      return ShapedType::isDynamic(v.getSExtValue());
    }

    // Assert that the size of the result tensor is static at `idx`
    // and return the shape.
    int64_t getStaticSize(unsigned idx) {
      assert(!isDynamicSize(idx) && "expected static size");
      APInt v = *(static_sizes().
          template getAsValueRange<IntegerAttr>().begin() + idx);
        return v.getSExtValue();
    }

    // Return the argument position that contains the dynamic size of
    // the tensor at dimension `idx`. Asserts that the shape is
    // dynamic at that `idx`.
    unsigned getIndexOfDynamicSize(unsigned idx) {
      assert(isDynamicSize(idx) && "expected dynamic size");
      return std::count_if(
          static_sizes().getValue().begin(),
          static_sizes().getValue().begin() + idx,
          [&](Attribute attr) {
            return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
          });
    }

    // Return both static and dynamic sizes as a list of `OpFoldResult`.
    SmallVector<OpFoldResult> getMixedSizes();

    // Return the Value of the dynamic size of the tensor at dimension
    // `idx`. Asserts that the shape is dynamic at that `idx.
    Value getDynamicSize(unsigned idx) {
      return getOperand(getIndexOfDynamicSize(idx));
    }
  }];

  let builders = [
    OpBuilder<(ins "ValueRange":$shape,
                  "ArrayRef<int64_t>":$staticShape, "Type":$elementType),
    [{
      build($_builder, $_state,
            AllocTensorOp::inferResultType(staticShape, elementType),
            shape, $_builder.getI64ArrayAttr(staticShape));
    }]>,
    OpBuilder<(ins "ValueRange":$shape, "Type":$elementType),
    [{
      SmallVector<int64_t, 4> staticShape(
        shape.size(), ShapedType::kDynamicSize);
      build($_builder, $_state, shape, staticShape, elementType);
    }]>,
    OpBuilder<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType),
    [{
      build($_builder, $_state, ValueRange{}, staticShape, elementType);
    }]>,
    OpBuilder<(ins "ArrayRef<OpFoldResult>":$sizes, "Type":$elementType,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
  ];

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//
+48 −0
Original line number Diff line number Diff line
//===- AllocTensorElimination.h - alloc_tensor op elimination -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

namespace mlir {
namespace bufferization {

/// A function that matches anchor OpOperands for AllocTensorOp elimination.
/// If an OpOperand is matched, the function should populate the SmallVector
/// with all values that are needed during `RewriteFn` to produce the
/// replacement value.
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;

/// A function that rewrites matched anchors.
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;

/// Try to eliminate AllocTensorOps inside `op`.
///
/// * `rewriteFunc` generates the replacement for the AllocTensorOp.
/// * Only AllocTensorOps that are anchored on a matching OpOperand as per
///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
///   on the reverse SSA use-def chain, starting from the OpOperand and always
///   following the aliasing  OpOperand, that eventually ends at a single
///   AllocTensorOp.
LogicalResult eliminateAllocTensors(RewriterBase &rewriter, Operation *op,
                                    bufferization::AnalysisState &state,
                                    AnchorMatchFn anchorMatchFunc,
                                    RewriteFn rewriteFunc);

/// Try to eliminate AllocTensorOps inside `op` that are anchored on an
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
LogicalResult insertSliceAnchoredAllocTensorEliminationStep(
    RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);

} // namespace bufferization
} // namespace mlir

#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H
+7 −0
Original line number Diff line number Diff line
@@ -64,6 +64,13 @@ createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
std::unique_ptr<Pass>
createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);

/// Create a pass that tries to eliminate alloc_tensor ops that are anchored on
/// insert_slice ops.
std::unique_ptr<Pass> createAllocTensorEliminationPass();

/// Create a pass that bufferizes ops from the bufferization dialect.
std::unique_ptr<Pass> createBufferizationBufferizePass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Loading