Unverified Commit f6012dd7 authored by Sergio Afonso's avatar Sergio Afonso Committed by GitHub
Browse files

[MLIR][OpenMP] Refactor omp.target_allocmem to allow reuse, NFC (#161861)

This patch moves tablegen definitions that could be used for all kinds
of heap allocations out of `omp.target_allocmem` and into a new
`OpenMP_HeapAllocClause` that can be reused.

Descriptions are updated to follow the format of most other operations
and the custom verifier for `omp.target_allocmem` is removed as it only
made a redundant check on its result type.
parent b5c75514
Loading
Loading
Loading
Loading
+53 −0
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@
#define OPENMP_CLAUSES

include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"

@@ -579,6 +580,58 @@ class OpenMP_HasDeviceAddrClauseSkip<

def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;

//===----------------------------------------------------------------------===//
// Not in the spec: Clause-like structure to hold heap allocation information.
//===----------------------------------------------------------------------===//

class OpenMP_HeapAllocClauseSkip<
    bit traits = false, bit arguments = false, bit assemblyFormat = false,
    bit description = false, bit extraClassDeclaration = false
  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                    extraClassDeclaration> {
  let traits = [
    MemoryEffects<[MemAlloc<DefaultResource>]>
  ];

  let arguments = (ins
    TypeAttr:$in_type,
    OptionalAttr<StrAttr>:$uniq_name,
    OptionalAttr<StrAttr>:$bindc_name,
    Variadic<IntLikeType>:$typeparams,
    Variadic<IntLikeType>:$shape
  );

  // The custom parser doesn't parse `uniq_name` and `bindc_name`. This is
  // handled by the attr-dict, which must be present in the operation's
  // `assemblyFormat`.
  let reqAssemblyFormat = [{
    custom<HeapAllocClause>($in_type, $typeparams, type($typeparams), $shape,
                            type($shape))
  }];

  let extraClassDeclaration = [{
    mlir::Type getAllocatedType() { return getInTypeAttr().getValue(); }
  }];

  let description = [{
    The `in_type` is the type of the object for which memory is being allocated.
    For arrays, this can be a static or dynamic array type.

    The optional `uniq_name` is a unique name for the allocated memory.

    The optional `bindc_name` is a name used for C interoperability.

    The `typeparams` are runtime type parameters for polymorphic or
    parameterized types. These are typically integer values that define aspects
    of a type not fixed at compile time.

    The `shape` holds runtime shape operands for dynamic arrays. Each operand is
    an integer value representing the extent of a specific dimension.
  }];
}

def OpenMP_HeapAllocClause : OpenMP_HeapAllocClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [5.4.7] `inclusive` clause
//===----------------------------------------------------------------------===//
+33 −47
Original line number Diff line number Diff line
@@ -2247,25 +2247,18 @@ def AllocateFreeOp : OpenMP_Op<"allocate_free", [AttrSizedOperandSegments],
// TargetAllocMemOp
//===----------------------------------------------------------------------===//

def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
    [MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
def TargetAllocMemOp : OpenMP_Op<"target_allocmem", traits = [
    AttrSizedOperandSegments
  ], clauses = [
    OpenMP_HeapAllocClause
  ]> {
  let summary = "allocate storage on an openmp device for an object of a given type";

  let description = [{
    Allocates memory on the specified OpenMP device for an object of the given type.
    Returns an integer value representing the device pointer to the allocated memory.
    The memory is uninitialized after allocation. Operations must be paired with 
    `omp.target_freemem` to avoid memory leaks.

    * `$device`: The integer ID of the OpenMP device where the memory will be allocated.
    * `$in_type`: The type of the object for which memory is being allocated. 
      For arrays, this can be a static or dynamic array type.
    * `$uniq_name`: An optional unique name for the allocated memory.
    * `$bindc_name`: An optional name used for C interoperability.
    * `$typeparams`: Runtime type parameters for polymorphic or parameterized types. 
      These are typically integer values that define aspects of a type not fixed at compile time.
    * `$shape`: Runtime shape operands for dynamic arrays. 
      Each operand is an integer value representing the extent of a specific dimension. 
    Allocates memory on the specified OpenMP device for an object of the given
    type. Returns an integer value representing the device pointer to the
    allocated memory. The memory is uninitialized after allocation. Operations
    must be paired with  `omp.target_freemem` to avoid memory leaks.

    ```mlir
      // Allocate a static 3x3 integer vector on device 0
@@ -2282,24 +2275,17 @@ def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
      // ... use %ptr_dynamic ...
      omp.target_freemem %device_1, %ptr_dynamic : i32, i64
    ```
  }];

  let arguments = (ins
    Arg<AnyInteger>:$device,
    TypeAttr:$in_type,
    OptionalAttr<StrAttr>:$uniq_name,
    OptionalAttr<StrAttr>:$bindc_name,
    Variadic<IntLikeType>:$typeparams,
    Variadic<IntLikeType>:$shape
  );
  let results = (outs I64);
    The `device` is an integer ID of the OpenMP device where the memory will be
    allocated.
  }] # clausesDescription;

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
  let arguments = !con((ins Arg<AnyInteger>:$device), clausesArgs);
  let results = (outs I64);

  let extraClassDeclaration = [{
    mlir::Type getAllocatedType();
  }];
  // Override inherited assembly format to include `device`.
  let assemblyFormat = " $device `:` type($device) `,` "
                     # clausesReqAssemblyFormat # " attr-dict";
}

//===----------------------------------------------------------------------===//
+52 −101
Original line number Diff line number Diff line
@@ -874,6 +874,58 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
      p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}

//===----------------------------------------------------------------------===//
// Parser and printer for Heap Alloc Clause
//===----------------------------------------------------------------------===//

/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
static ParseResult parseHeapAllocClause(
    OpAsmParser &parser, TypeAttr &inTypeAttr,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &typeparams,
    SmallVectorImpl<Type> &typeparamsTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &shape,
    SmallVectorImpl<Type> &shapeTypes) {
  mlir::Type inType;
  if (parser.parseType(inType))
    return mlir::failure();
  inTypeAttr = TypeAttr::get(inType);

  if (!parser.parseOptionalLParen()) {
    // parse the LEN params of the derived type. (<params> : <types>)
    if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) ||
        parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen())
      return failure();
  }

  if (!parser.parseOptionalComma()) {
    // parse size to scale by, vector of n dimensions of type index
    if (parser.parseOperandList(shape, OpAsmParser::Delimiter::None))
      return failure();

    // TODO: This overrides the actual types of the operands, which might cause
    // issues when they don't match. At the moment this is done in place of
    // making the corresponding operand type `Variadic<Index>` because index
    // types are lowered to I64 prior to LLVM IR translation.
    shapeTypes.append(shape.size(), IndexType::get(parser.getContext()));
  }

  return success();
}

static void printHeapAllocClause(OpAsmPrinter &p, Operation *op,
                                 TypeAttr inType, ValueRange typeparams,
                                 TypeRange typeparamsTypes, ValueRange shape,
                                 TypeRange shapeTypes) {
  p << inType;
  if (!typeparams.empty()) {
    p << '(' << typeparams << " : " << typeparamsTypes << ')';
  }
  for (auto sh : shape) {
    p << ", ";
    p.printOperand(sh);
  }
}

//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -4651,107 +4703,6 @@ LogicalResult AllocateDirOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// TargetAllocMemOp
//===----------------------------------------------------------------------===//

mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
  return getInTypeAttr().getValue();
}

/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
///                      $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
///                      attr-dict-without-keyword
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
                                               mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  bool hasOperands = false;
  std::int32_t typeparamsSize = 0;

  // Parse device number as a new operand
  mlir::OpAsmParser::UnresolvedOperand deviceOperand;
  mlir::Type deviceType;
  if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
    return mlir::failure();
  if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
    return mlir::failure();
  if (parser.parseComma())
    return mlir::failure();

  mlir::Type intype;
  if (parser.parseType(intype))
    return mlir::failure();
  result.addAttribute("in_type", mlir::TypeAttr::get(intype));
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
  llvm::SmallVector<mlir::Type> typeVec;
  if (!parser.parseOptionalLParen()) {
    // parse the LEN params of the derived type. (<params> : <types>)
    if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
        parser.parseColonTypeList(typeVec) || parser.parseRParen())
      return mlir::failure();
    typeparamsSize = operands.size();
    hasOperands = true;
  }
  std::int32_t shapeSize = 0;
  if (!parser.parseOptionalComma()) {
    // parse size to scale by, vector of n dimensions of type index
    if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
      return mlir::failure();
    shapeSize = operands.size() - typeparamsSize;
    auto idxTy = builder.getIndexType();
    for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
      typeVec.push_back(idxTy);
    hasOperands = true;
  }
  if (hasOperands &&
      parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
                             result.operands))
    return mlir::failure();

  mlir::Type restype = builder.getIntegerType(64);
  if (!restype) {
    parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
    return mlir::failure();
  }
  llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
  result.addAttribute("operandSegmentSizes",
                      builder.getDenseI32ArrayAttr(segmentSizes));
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.addTypeToList(restype, result.types))
    return mlir::failure();
  return mlir::success();
}

mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
                                               mlir::OperationState &result) {
  return parseTargetAllocMemOp(parser, result);
}

void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
  p << " ";
  p.printOperand(getDevice());
  p << " : ";
  p << getDevice().getType();
  p << ", ";
  p << getInType();
  if (!getTypeparams().empty()) {
    p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
  }
  for (auto sh : getShape()) {
    p << ", ";
    p.printOperand(sh);
  }
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {"in_type", "operandSegmentSizes"});
}

llvm::LogicalResult omp::TargetAllocMemOp::verify() {
  mlir::Type outType = getType();
  if (!mlir::dyn_cast<IntegerType>(outType))
    return emitOpError("must be a integer type");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// WorkdistributeOp
//===----------------------------------------------------------------------===//
+14 −0
Original line number Diff line number Diff line
@@ -3449,3 +3449,17 @@ func.func @iterator_yield_type_mismatch(%lb : index, %ub : index, %st : index) {
  } -> !omp.iterated<i64>
  return
}

// -----
func.func @target_allocmem_invalid_uniq_name(%device : i32) -> () {
// expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
  %0 = omp.target_allocmem %device : i32, i64 {uniq_name=2}
  return
}

// -----
func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
// expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
  %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
  return
}
+24 −0
Original line number Diff line number Diff line
@@ -3905,3 +3905,27 @@ func.func @omp_task_affinity_iterator_2d(%lb0 : index, %ub0 : index, %st0 : inde

  return
}

// CHECK-LABEL: func.func @omp_target_allocmem(
// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
  // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, i64
  %0 = omp.target_allocmem %device : i32, i64
  // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"}
  %1 = omp.target_allocmem %device : i32, vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"}
  // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32)
  %2 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32)
  // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr, %[[X]], %[[Y]]
  %3 = omp.target_allocmem %device : i32, !llvm.ptr, %x, %y
  // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]]
  %4 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y
  return
}

// CHECK-LABEL: func.func @omp_target_freemem(
// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
func.func @omp_target_freemem(%device : i32, %ptr : i64) {
  // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
  omp.target_freemem %device, %ptr : i32, i64
  return
}