Unverified Commit 510380a4 authored by Charitha Saumya, committed by GitHub

[mlir][xegpu] Support for nD memrefs in `vector.transfer_read` with transposed permutation maps. (#195197)

The current implementation fails when loading from a 3D memref with a permutation
map such as (d0, d1, d2) -> (d2, d1), i.e. when only the last two dimensions are
read and interchanged.
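
For context: the old lowering classified any non-minor-identity map as a
transpose and then inverted it with inversePermutation(), which only succeeds
for full permutations. A projected map like (d0, d1, d2) -> (d2, d1) leaves d0
with no preimage, so no valid inverse exists. A minimal standalone sketch (not
part of this commit) of that behavior, assuming stock upstream MLIR AffineMap
APIs:

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  auto d1 = mlir::getAffineDimExpr(1, &ctx);
  auto d2 = mlir::getAffineDimExpr(2, &ctx);
  // Build (d0, d1, d2) -> (d2, d1): 3 inputs, 2 results.
  auto map = mlir::AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0,
                                  {d2, d1}, &ctx);
  llvm::outs() << "isPermutation: " << map.isPermutation() << "\n";     // 0
  llvm::outs() << "isMinorIdentity: " << map.isMinorIdentity() << "\n"; // 0
  // Upstream inversePermutation() yields a null map for this input,
  // which is what the old lowering tripped over downstream.
  llvm::outs() << "inverse is null: " << !mlir::inversePermutation(map)
               << "\n";                                                 // 1
}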
parent c5771226

+29 −11
@@ -561,7 +561,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
      AffineMap readMap = readOp.getPermutationMap();
      if (!readMap.isMinorIdentity())
        return rewriter.notifyMatchFailure(
-          readOp, "Transpose not supported for SLM loads");
+          readOp,
+          "Non identity transposition is not supported for SLM loads.");
      // Out of bounds case is not supported for SLM loads.
      if (isOutOfBounds)
        return rewriter.notifyMatchFailure(
@@ -613,16 +614,33 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
-    bool isTransposeLoad = !readMap.isMinorIdentity();
+    // Check if this is a transpose: the map must have exactly 2 results,
+    // and those 2 results must be the last 2 input dimensions interchanged.
+    // Examples:
+    //   (d0, d1) -> (d1, d0)      // transpose
+    //   (d0, d1) -> (d0, d1)      // not a transpose
+    //   (d0, d1, d2) -> (d2, d1)  // transpose (last 2 dims swapped)
+    bool isTransposeLoad = false;
+    if (readMap.getNumResults() == 2) {
+      auto results = readMap.getResults();
+      unsigned numInputs = readMap.getNumInputs();
+      if (numInputs >= 2) {
+        auto lastDim = getAffineDimExpr(numInputs - 1, readMap.getContext());
+        auto secondLastDim =
+            getAffineDimExpr(numInputs - 2, readMap.getContext());
+        isTransposeLoad =
+            (results[0] == lastDim && results[1] == secondLastDim);
+      }
+    }
    auto elementType = loadedVecTy.getElementType();

    SmallVector<int64_t> descShape(loadedVecTy.getShape());
    if (isTransposeLoad) {
-      // If load is transposed, then the shape of the source-descriptor
-      // is the opposite from the result-shape. Applying the permutation
-      // to get the reversive shape.
-      auto inversedMap = inversePermutation(readMap);
-      descShape = applyPermutationMap(inversedMap, loadedVecTy.getShape());
+      // If load is transposed, simply swap the last two dimensions of the
+      // loaded vector type to get the descriptor shape.
+      size_t rank = descShape.size();
+      assert(rank >= 2 && "Transpose requires at least 2 dimensions");
+      std::swap(descShape[rank - 1], descShape[rank - 2]);
      loadedVecTy = VectorType::get(descShape, elementType);
    }
    auto descType = xegpu::TensorDescType::get(
@@ -646,10 +664,10 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
      // Transposing the loaded vector with a separate vector.transpose
      // operation
      auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
-      SmallVector<int64_t> perm(range.begin(), range.end());
-      auto permApplied = applyPermutationMap<int64_t>(readMap, perm);
-      loadedOp = vector::TransposeOp::create(
-          rewriter, loc, loadedOp->getResult(0), permApplied);
+      SmallVector<int64_t> perm(
+          range.rbegin(), range.rend()); // reverse the range for transpose
+      loadedOp = vector::TransposeOp::create(rewriter, loc,
+                                             loadedOp->getResult(0), perm);
    }
    rewriter.replaceOp(readOp, loadedOp);

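To make the new detection logic concrete, here is a small standalone sketch
(not part of the commit; isTransposeOfLastTwoDims is a hypothetical helper
name) that applies the same predicate as the new isTransposeLoad computation
to the example maps from the comment above:

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

// Exactly 2 results, and they must be the last two input dimensions in
// swapped order; this mirrors the check added in this commit.
static bool isTransposeOfLastTwoDims(mlir::AffineMap map) {
  if (map.getNumResults() != 2 || map.getNumInputs() < 2)
    return false;
  unsigned n = map.getNumInputs();
  auto lastDim = mlir::getAffineDimExpr(n - 1, map.getContext());
  auto secondLastDim = mlir::getAffineDimExpr(n - 2, map.getContext());
  return map.getResult(0) == lastDim && map.getResult(1) == secondLastDim;
}

int main() {
  mlir::MLIRContext ctx;
  auto d0 = mlir::getAffineDimExpr(0, &ctx);
  auto d1 = mlir::getAffineDimExpr(1, &ctx);
  auto d2 = mlir::getAffineDimExpr(2, &ctx);
  auto m1 = mlir::AffineMap::get(2, 0, {d1, d0}, &ctx); // (d1, d0): transpose
  auto m2 = mlir::AffineMap::get(2, 0, {d0, d1}, &ctx); // identity: not one
  auto m3 = mlir::AffineMap::get(3, 0, {d2, d1}, &ctx); // last 2 swapped: yes
  llvm::outs() << isTransposeOfLastTwoDims(m1) << " "   // 1
               << isTransposeOfLastTwoDims(m2) << " "   // 0
               << isTransposeOfLastTwoDims(m3) << "\n"; // 1
}
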
+40 −0
@@ -142,6 +142,46 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,

}

+// -----
+gpu.module @xevm_module {
+gpu.func @load_transpose_3d_memref(%source: memref<32x64x128xf32>,
+    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%i, %j, %k], %c0
+    {permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>,
+    in_bounds = [true, true]} : memref<32x64x128xf32>, vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+
+// LOAD-ND-LABEL:  @load_transpose_3d_memref(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64x128xf32>,
+// LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index) -> vector<8x16xf32> {
+// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 4 : index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND-SAME:     : memref<f32> -> index
+// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [64, 128],
+// LOAD-ND-SAME:                   strides : [128, 1] : i64 -> !xegpu.tensor_desc<16x8xf32,
+// LOAD-ND-SAME:     #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]
+// LOAD-ND-SAME:     : !xegpu.tensor_desc<16x8xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<16x8xf32>
+// LOAD-ND:        %[[VEC_TRANSPOSED:.+]] = vector.transpose %[[VEC]], [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+
+// LOAD-GATHER-LABEL:  @load_transpose_3d_memref(
+// LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64x128xf32>,
+// LOAD-GATHER-SAME:    %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index
+// LOAD-GATHER:         %[[BCAST3:.+]] = vector.broadcast %{{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER:         %[[IDX:.+]] = arith.addi %[[BCAST3]], %{{.*}} : vector<8x16xindex>
+// LOAD-GATHER:         %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64x128xf32> -> index
+// LOAD-GATHER-NEXT:    %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-GATHER-NEXT:    %[[LOAD:.*]] = xegpu.load %[[I64PTR]][%[[IDX]]], %{{.*}} : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+
+}
+
// -----
gpu.module @xevm_module {
gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,