Unverified commit cab0c0dd, authored by Varad Rahul Kamthe, committed by GitHub
Browse files

[MLIR][NVVM] Add movmatrix Op (#193995)

Add the `movmatrix` op to the MLIR NVVM dialect. The op moves a row-major
matrix across all threads in a warp and writes the transposed elements to the
destination.
parent 796d2ec4
Loading
Loading
Loading
Loading
+30 −0
Original line number Diff line number Diff line
@@ -3326,6 +3326,36 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
  let hasVerifier = 1;
}

// `nvvm.movmatrix`: single-result intrinsic op taking one i32 source and
// producing one i32 destination. The op verifier only accepts an 8x8 shape,
// the `col` layout and the `b16` element type; the translation to LLVM IR
// emits the `nvvm_movmatrix_sync_aligned_m8n8_trans_b16` intrinsic.
def NVVM_MovMatrixOp
    : NVVM_SingleResultIntrinsicOp<"movmatrix", [NVVMRequiresSM<75>], "$dst"> {
  let summary = "Warp-level matrix transpose";
  let description = [{
    Moves a row-major matrix across all threads in a warp, reading elements
    from source `$src`, and writing the transposed elements to destination
    `$dst`.

    The `shape` attribute indicates the dimensions of the matrix being
    transposed; only `m = 8, n = 8` is accepted by the verifier. Each matrix
    element holds 16-bit data as indicated by the `eltType` attribute; only
    `b16` is accepted by the verifier.

    The optional `layout` attribute defaults to `col`, which is the only
    value accepted by the verifier.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-movmatrix-instruction)

    Example:
    ```mlir
    %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
                                eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
    ```
  }];

  let results = (outs I32:$dst);
  // `layout` is default-valued, so it may be omitted in the attribute
  // dictionary (as in the example above).
  let arguments = (ins I32:$src, LdStMatrixShapeAttr:$shape,
      DefaultValuedAttr<MMALayoutAttr, "MMALayout::col">:$layout,
      LdStMatrixEltTypeAttr:$eltType);

  // All attributes are carried in the generic attribute dictionary; only the
  // source operand and its type appear in the custom syntax.
  let assemblyFormat = "$src attr-dict `:` type($src)";
  let hasVerifier = 1;
}

def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {

  let summary = "cooperative matrix-multiply and accumulate";
+19 −0
Original line number Diff line number Diff line
@@ -2607,6 +2607,17 @@ LogicalResult NVVM::StMatrixOp::verify() {
  return success();
}

/// Verifies the attributes of `nvvm.movmatrix`.
///
/// The constraints are checked in a fixed order (shape, then layout, then
/// element type); the first violated one produces the diagnostic and the
/// remaining ones are not inspected.
LogicalResult NVVM::MovMatrixOp::verify() {
  // The shape must be exactly m = 8, n = 8.
  auto matrixShape = getShape();
  const bool isM8N8 = matrixShape.getM() == 8 && matrixShape.getN() == 8;
  if (!isM8N8)
    return emitOpError("expected shape to be 8x8");

  // The layout must be `col`.
  const bool isColLayout = getLayout() == NVVM::MMALayout::col;
  if (!isColLayout)
    return emitOpError("expected layout to be col");

  // The element type must be `b16`.
  const bool isB16 = getEltType() == NVVM::LdStMatrixEltType::B16;
  if (!isB16)
    return emitOpError("expected element type to be b16");

  return success();
}

static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
@@ -3884,6 +3895,14 @@ mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
  return {id, {mt.lookupValue(thisOp.getAddr())}};
}

/// Returns the LLVM intrinsic ID and argument list used to translate
/// `nvvm.movmatrix`.
///
/// The intrinsic is always the m8n8 / trans / b16 variant; its single argument
/// is the translated value of the op's source operand. `builder` is not used.
mlir::NVVM::IDArgPair
MovMatrixOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                   llvm::IRBuilderBase &builder) {
  auto movMatrixOp = cast<NVVM::MovMatrixOp>(op);
  llvm::Intrinsic::ID intrinsicId =
      llvm::Intrinsic::nvvm_movmatrix_sync_aligned_m8n8_trans_b16;
  llvm::Value *translatedSrc = mt.lookupValue(movMatrixOp.getSrc());
  return {intrinsicId, {translatedSrc}};
}

#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

+8 −0
Original line number Diff line number Diff line
@@ -133,6 +133,14 @@ func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
  llvm.return %0 : i32
}

// nvvm.movmatrix with an 8x8 shape and b16 element type. The directive below
// matches the op with its attribute dictionary listing eltType before shape
// and without a `layout` entry (presumably because the default-valued layout
// is elided on printing; the RUN line is not visible here).
// CHECK-LABEL: @nvvm_movmatrix
func.func @nvvm_movmatrix(%src : i32) -> i32 {
  // CHECK: nvvm.movmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : i32
  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
  llvm.return %dst : i32
}

// CHECK-LABEL: @llvm_nvvm_bar_warp_sync
func.func @llvm_nvvm_bar_warp_sync(%mask : i32) {
  // CHECK: nvvm.bar.warp.sync %{{.*}}
+28 −0
Original line number Diff line number Diff line
@@ -541,6 +541,34 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {

// -----

// Negative case for nvvm.movmatrix: a shape of m = 8, n = 16 must be rejected
// by the op verifier with the 8x8 shape diagnostic.
llvm.func @mov_matrix(%src : i32) -> i32 {
  // expected-error@+1 {{'nvvm.movmatrix' op expected shape to be 8x8}}
  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,
                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
  llvm.return %dst : i32
}

// -----

// Negative case for nvvm.movmatrix: an explicit `row` layout must be rejected
// by the op verifier; shape and element type are otherwise valid.
llvm.func @mov_matrix(%src : i32) -> i32 {
  // expected-error@+1 {{'nvvm.movmatrix' op expected layout to be col}}
  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
                              layout = #nvvm.mma_layout<row>,
                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
  llvm.return %dst : i32
}

// -----

// Negative case for nvvm.movmatrix: a `b8` element type must be rejected by
// the op verifier; the shape is valid and the layout is left at its default.
llvm.func @mov_matrix(%src : i32) -> i32 {
  // expected-error@+1 {{'nvvm.movmatrix' op expected element type to be b16}}
  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
                              eltType = #nvvm.ld_st_matrix_elt_type<b8>} : i32
  llvm.return %dst : i32
}

// -----

llvm.func @clusterlaunchcontrol_query_cancel_is_canceled_invalid_return_type(%try_cancel_response: i128) {
  // expected-error@+1 {{'nvvm.clusterlaunchcontrol.query.cancel' op is_canceled query type returns an i1}}
  %res = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %try_cancel_response : i32
+8 −0
Original line number Diff line number Diff line
@@ -617,6 +617,14 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32
  llvm.return
}

// nvvm.movmatrix with an 8x8 shape and b16 element type must translate to a
// call to the llvm.nvvm.movmatrix.sync.aligned.m8n8.trans.b16 intrinsic that
// takes a single i32 argument and returns i32.
// CHECK-LABEL: @nvvm_movmatrix
llvm.func @nvvm_movmatrix(%src : i32) -> i32 {
  // CHECK: call i32 @llvm.nvvm.movmatrix.sync.aligned.m8n8.trans.b16(i32 %{{.*}})
  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
  llvm.return %dst : i32
}

// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {