Commit 2eff566b authored by Stephan Herhut's avatar Stephan Herhut
Browse files

[MLIR] Add `and`, `or`, `xor`, `min`, `max` too gpu.all_reduce and the nvvm lowering

Summary:
This patch add some builtin operation for the gpu.all_reduce ops.
- for Integer only: `and`, `or`, `xor`
- for Float and Integer: `min`, `max`

This is useful for higher level dialect like OpenACC or OpenMP that can lower to the GPU dialect.

Differential Revision: https://reviews.llvm.org/D75766
parent 7fb562c1
Loading
Loading
Loading
Loading
+13 −3
Original line number Diff line number Diff line
@@ -482,15 +482,25 @@ def GPU_YieldOp : GPU_Op<"yield", [Terminator]>,
  }];
}

// These mirror the XLA ComparisonDirection enum.
// add, mul mirror the XLA ComparisonDirection enum.
def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;

def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
    "built-in reduction operations supported by gpu.allreduce.",
    [
      GPU_AllReduceOpAdd,
      GPU_AllReduceOpAnd,
      GPU_AllReduceOpMax,
      GPU_AllReduceOpMin,
      GPU_AllReduceOpMul,
      GPU_AllReduceOpOr,
      GPU_AllReduceOpXor
    ]>;

def GPU_AllReduceOp : GPU_Op<"all_reduce",
@@ -514,8 +524,8 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
    ```
    compute the sum of each work item's %0 value. The first version specifies
    the accumulation as operation, whereas the second version specifies the
    accumulation as code region. The accumulation operation must either be
    `add` or `mul`.
    accumulation as code region. The accumulation operation must be one of:
    `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.

    Either none or all work items of a workgroup need to execute this op
    in convergence.
+2 −0
Original line number Diff line number Diff line
@@ -211,6 +211,8 @@ _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);

extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i32(int64_t rank,
                                                         void *ptr);
extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f32(int64_t rank,
                                                         void *ptr);

+36 −3
Original line number Diff line number Diff line
@@ -123,18 +123,51 @@ private:
      return isFloatingPoint ? getFactory<LLVM::FMulOp>()
                             : getFactory<LLVM::MulOp>();
    }
    if (opName == "and") {
      return getFactory<LLVM::AndOp>();
    }
    if (opName == "or") {
      return getFactory<LLVM::OrOp>();
    }
    if (opName == "xor") {
      return getFactory<LLVM::XOrOp>();
    }
    if (opName == "max") {
      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
                                             LLVM::FCmpPredicate::ugt>()
                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
                                             LLVM::ICmpPredicate::ugt>();
    }
    if (opName == "min") {
      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
                                             LLVM::FCmpPredicate::ult>()
                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
                                             LLVM::ICmpPredicate::ult>();
    }

    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T> AccumulatorFactory getFactory() const {
  template <typename T>
  AccumulatorFactory getFactory() const {
    return [](Location loc, Value lhs, Value rhs,
              ConversionPatternRewriter &rewriter) {
      return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator for comparaison such as min, max. T is the type
  /// of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [](Location loc, Value lhs, Value rhs,
              ConversionPatternRewriter &rewriter) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an all_reduce across the block.
  ///
  /// First reduce the elements within a warp. The first thread of each warp
+8 −0
Original line number Diff line number Diff line
@@ -148,6 +148,14 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  } else {
    StringRef opName = *allReduce.op();
    if ((opName == "and" || opName == "or" || opName == "xor") &&
        !allReduce.getType().isa<IntegerType>()) {
      return allReduce.emitError()
             << '`' << opName << '`'
             << " accumulator is only compatible with Integer type";
    }
  }
  return success();
}
+29 −0
Original line number Diff line number Diff line
@@ -212,6 +212,25 @@ private:
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and") {
      return getFactory<AndOp>();
    }
    if (opName == "or") {
      return getFactory<OrOp>();
    }
    if (opName == "xor") {
      return getFactory<XOrOp>();
    }
    if (opName == "max") {
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    }
    if (opName == "min") {
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    }
    return AccumulatorFactory();
  }

@@ -222,6 +241,16 @@ private:
    };
  }

  /// Returns an accumulator for comparaison such as min, max. T is the type
  /// of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` block..
  ///
Loading