Commit e2b71610 authored by Rahul Joshi's avatar Rahul Joshi
Browse files

[MLIR] Add argument related API to Region

- Arguments of the first block of a region are considered region arguments.
- Add API on Region class to deal with these arguments directly instead of
  using the front() block.
- Changed several instances of existing code that can use this API
- Fixes https://bugs.llvm.org/show_bug.cgi?id=46535

Differential Revision: https://reviews.llvm.org/D83599
parent 85bed2f3
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -237,7 +237,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
    /// the workgroup memory
    ArrayRef<BlockArgument> getWorkgroupAttributions() {
      auto begin =
          std::next(getBody().front().args_begin(), getType().getNumInputs());
          std::next(getBody().args_begin(), getType().getNumInputs());
      auto end = std::next(begin, getNumWorkgroupAttributions());
      return {begin, end};
    }
@@ -248,7 +248,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,

    /// Returns the number of buffers located in the private memory.
    unsigned getNumPrivateAttributions() {
      return getBody().front().getNumArguments() - getType().getNumInputs() -
      return getBody().getNumArguments() - getType().getNumInputs() -
          getNumWorkgroupAttributions();
    }
 
@@ -258,9 +258,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
      // Buffers on the private memory always come after buffers on the workgroup
      // memory.
      auto begin =
          std::next(getBody().front().args_begin(),
          std::next(getBody().args_begin(),
                    getType().getNumInputs() + getNumWorkgroupAttributions());
      return {begin, getBody().front().args_end()};
      return {begin, getBody().args_end()};
    }

    /// Adds a new block argument that corresponds to buffers located in
+1 −1
Original line number Diff line number Diff line
@@ -583,7 +583,7 @@ def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
  let extraClassDeclaration = [{
    // The value stored in memref[ivs].
    Value getCurrentValue() {
      return body().front().getArgument(0);
      return body().getArgument(0);
    }
    MemRefType getMemRefType() {
      return memref().getType().cast<MemRefType>();
+5 −7
Original line number Diff line number Diff line
@@ -216,15 +216,13 @@ public:
  }

  /// Gets argument.
  BlockArgument getArgument(unsigned idx) {
    return getBlocks().front().getArgument(idx);
  }
  BlockArgument getArgument(unsigned idx) { return getBody().getArgument(idx); }

  /// Support argument iteration.
  using args_iterator = Block::args_iterator;
  args_iterator args_begin() { return front().args_begin(); }
  args_iterator args_end() { return front().args_end(); }
  Block::BlockArgListType getArguments() { return front().getArguments(); }
  using args_iterator = Region::args_iterator;
  args_iterator args_begin() { return getBody().args_begin(); }
  args_iterator args_end() { return getBody().args_end(); }
  Block::BlockArgListType getArguments() { return getBody().getArguments(); }

  //===--------------------------------------------------------------------===//
  // Argument Attributes
+45 −0
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@
#include "mlir/IR/Block.h"

namespace mlir {
class TypeRange;
template <typename ValueRangeT>
class ValueTypeRange;
class BlockAndValueMapping;

/// This class contains a list of basic blocks and a link to the parent
@@ -62,6 +65,48 @@ public:
    return &Region::blocks;
  }

  //===--------------------------------------------------------------------===//
  // Argument Handling
  //===--------------------------------------------------------------------===//

  // This is the list of arguments to the block.
  using BlockArgListType = MutableArrayRef<BlockArgument>;
  BlockArgListType getArguments() {
    return empty() ? BlockArgListType() : front().getArguments();
  }
  using args_iterator = BlockArgListType::iterator;
  using reverse_args_iterator = BlockArgListType::reverse_iterator;
  args_iterator args_begin() { return getArguments().begin(); }
  args_iterator args_end() { return getArguments().end(); }
  reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
  reverse_args_iterator args_rend() { return getArguments().rend(); }

  bool args_empty() { return getArguments().empty(); }

  /// Add one value to the argument list.
  BlockArgument addArgument(Type type) { return front().addArgument(type); }

  /// Insert one value to the position in the argument list indicated by the
  /// given iterator. The existing arguments are shifted. The block is expected
  /// not to have predecessors.
  BlockArgument insertArgument(args_iterator it, Type type) {
    return front().insertArgument(it, type);
  }

  /// Add one argument to the argument list for each type specified in the list.
  iterator_range<args_iterator> addArguments(TypeRange types);

  /// Add one value to the argument list at the specified position.
  BlockArgument insertArgument(unsigned index, Type type) {
    return front().insertArgument(index, type);
  }

  /// Erase the argument at 'index' and remove it from the argument list.
  void eraseArgument(unsigned index) { front().eraseArgument(index); }

  unsigned getNumArguments() { return getArguments().size(); }
  BlockArgument getArgument(unsigned i) { return getArguments()[i]; }

  //===--------------------------------------------------------------------===//
  // Operation list utilities
  //===--------------------------------------------------------------------===//
+2 −2
Original line number Diff line number Diff line
@@ -417,8 +417,8 @@ static LogicalResult processParallelLoop(

    if (isMappedToProcessor(processor)) {
      // Use the corresponding thread/grid index as replacement for the loop iv.
      Value operand = launchOp.body().front().getArgument(
          getLaunchOpArgumentNum(processor));
      Value operand =
          launchOp.body().getArgument(getLaunchOpArgumentNum(processor));
      // Take the indexmap and add the lower bound and step computations in.
      // This computes operand * step + lowerBound.
      // Use an affine map here so that it composes nicely with the provided
Loading