Commit 4f865b77 authored by Tobias Gysi's avatar Tobias Gysi Committed by Alex Zinenko
Browse files

[mlir] support creating memref descriptors from static shape with non-zero offset

This patch adapts the method MemRefDescriptor::fromStaticShape to
support static non-zero offsets. The updated method uses the
getStridesAndOffset method to extract strides and offset. The patch also
adapts the test cases since sizes and strides are now set in forward
instead of reverse order.

Differential Revision: https://reviews.llvm.org/D74474
parent 56aba969
Loading
Loading
Loading
Loading
+17 −11
Original line number Diff line number Diff line
@@ -430,7 +430,17 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
                                  LLVMTypeConverter &typeConverter,
                                  MemRefType type, Value memory) {
  assert(type.hasStaticShape() && "unexpected dynamic shape");
  assert(type.getAffineMaps().empty() && "unexpected layout map");

  // Extract all strides and offsets and verify they are static.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto result = getStridesAndOffset(type, strides, offset);
  (void)result;
  assert(succeeded(result) && "unexpected failure in stride computation");
  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
         "expected static offset");
  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
         "expected static strides");

  auto convertedType = typeConverter.convertType(type);
  assert(convertedType && "unexpected failure in memref type conversion");
@@ -438,16 +448,12 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
  auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
  descr.setAllocatedPtr(builder, loc, memory);
  descr.setAlignedPtr(builder, loc, memory);
  descr.setConstantOffset(builder, loc, 0);

  // Fill in sizes and strides, in reverse order to simplify stride
  // calculation.
  uint64_t runningStride = 1;
  for (unsigned i = type.getRank(); i > 0; --i) {
    unsigned dim = i - 1;
    descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
    descr.setConstantStride(builder, loc, dim, runningStride);
    runningStride *= type.getDimSize(dim);
  descr.setConstantOffset(builder, loc, offset);

  // Fill in sizes and strides
  for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
    descr.setConstantSize(builder, loc, i, type.getDimSize(i));
    descr.setConstantStride(builder, loc, i, strides[i]);
  }
  return descr;
}
+8 −8
Original line number Diff line number Diff line
@@ -92,18 +92,18 @@ gpu.module @kernel {
    // CHECK: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // CHECK: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c6]], %[[descr4]][3, 2]
    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 2]
    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0]
    // CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
    // CHECK: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1]
    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // CHECK: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1]
    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c4]], %[[descr8]][3, 0]
    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c12]], %[[descr9]][4, 0]
    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2]
    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2]

    %c0 = constant 0 : index
    store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, 3>
+37 −9
Original line number Diff line number Diff line
@@ -24,20 +24,48 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  return %static : memref<32x18xf32>
}

// -----

// CHECK-LABEL: func @check_static_return_with_offset
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-LABEL: func @check_static_return_with_offset
// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> {
// CHECK:  llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(22 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  return %static : memref<32x18xf32, offset:7, strides:[22,1]>
}

// -----

// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
@@ -302,7 +330,7 @@ func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f
// BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
// CHECK:        llvm.mlir.constant(42 : index) : !llvm.i64
// BAREPTR:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// BAREPTR:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
  %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
// CHECK-NEXT:  llvm.mlir.constant(32 : index) : !llvm.i64