Unverified commit dc19e4b0, authored by Hao Ren, committed by GitHub
Browse files

[mlir][NVGPUToNVVM] Support BF16 mma.sync lowering (#194203)



Teach NVGPUToNVVM to recognize BF16 MMA operand element types and
pack `vector<2xbf16>` fragments into `i32` registers before emitting
`nvvm.mma.sync`.
This matches the PTX operand encoding for `m16n8k16` BF16 MMA
instructions.

Add a conversion test for `nvgpu.mma.sync` `bf16xbf16` to `f32`
lowering.

Co-authored-by: Hao Ren <rhao8608@gmail.com>
parent cab0c0dd
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -179,6 +179,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type bf16x2Ty = VectorType::get(2, b.getBF16Type());
  Type i8x4Ty = VectorType::get(4, b.getI8Type());
  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
  Type f32x1Ty = VectorType::get(1, f32Ty);
@@ -191,6 +192,8 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == bf16x2Ty &&
         operandPtxType == NVVM::MMATypes::bf16) ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
@@ -320,6 +323,8 @@ static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isBF16())
    return NVVM::MMATypes::bf16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
+23 −0
Original line number Diff line number Diff line
@@ -49,6 +49,29 @@ func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %a
  return %d : vector<2x2xf32>
}

// Verify lowering of a bf16 x bf16 -> f32 m16n8k16 nvgpu.mma.sync to
// nvvm.mma.sync: each vector<2xbf16> fragment of the A operand (4 registers)
// and the B operand (2 registers) is bitcast to a packed i32 register, and
// the emitted op carries bf16 multiplicand PTX types with an m16n8k16 shape.
// CHECK-LABEL: @m16n8k16_bf16_fp32
func.func @m16n8k16_bf16_fp32(%arg0: vector<4x2xbf16>, %arg1: vector<2x2xbf16>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // A operand: all 4 vector<2xbf16> fragments packed to i32.
  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xbf16>>
  // CHECK: llvm.bitcast {{.*}} : vector<2xbf16> to i32
  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xbf16>>
  // CHECK: llvm.bitcast {{.*}} : vector<2xbf16> to i32
  // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xbf16>>
  // CHECK: llvm.bitcast {{.*}} : vector<2xbf16> to i32
  // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xbf16>>
  // CHECK: llvm.bitcast {{.*}} : vector<2xbf16> to i32
  // B operand: both vector<2xbf16> fragments packed to i32.
  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xbf16>>
  // CHECK: llvm.bitcast {{.*}} : vector<2xbf16> to i32
  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xbf16>>
  // CHECK: llvm.bitcast {{.*}} : vector<2xbf16> to i32
  // The emitted nvvm.mma.sync takes the packed i32 A/B registers and f32
  // accumulator values, with bf16 PTX multiplicand types and m16n8k16 shape.
  // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}, {{%.+}}, {{%.+}}, {{%.+}}] B[{{%.+}}, {{%.+}}] C[{{%.+}}, {{%.+}}, {{%.+}}, {{%.+}}]
  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<bf16>
  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<bf16>
  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
  // CHECK-SAME: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xbf16>, vector<2x2xbf16>, vector<2x2xf32>) -> vector<2x2xf32>
  return %d : vector<2x2xf32>
}

// CHECK-LABEL: @m16n8k8_fp16
func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>