Unverified Commit f5c0d917 authored by Akhil Goel's avatar Akhil Goel Committed by GitHub
Browse files

[X86][AVX10.2] Use SDNode patterns based lowering for VMINBF16/VMAXBF16 (#194987)

This PR adds direct SDNode-based selection for AVX10.2 BF16 vmin/vmax.
This unblocks the select-minmax DAG combine which would earlier hit a
selection failure.
parent ce7fbcbf
Loading
Loading
Loading
Loading
+2 −24
Original line number Diff line number Diff line
@@ -980,28 +980,6 @@ defm VCVTHF82PH : avx10_convert_2op_nomb<"vcvthf82ph", avx512vl_f16_info,
//-------------------------------------------------

// VADDBF16, VSUBBF16, VMULBF16, VDIVBF16, VMAXBF16, VMINBF16
multiclass avx10_fp_binop_int_bf16<bits<8> opc, string OpcodeStr,
                                      X86SchedWriteSizes sched,
                                      bit IsCommutable = 0> {
  let Predicates = [HasAVX10_2] in {
    defm Z : avx512_fp_packed<opc, OpcodeStr,
                              !cast<Intrinsic>("int_x86_avx10_"#OpcodeStr#"bf16512"),
                              !cast<Intrinsic>("int_x86_avx10_"#OpcodeStr#"bf16512"),
                              v32bf16_info, sched.PH.ZMM, IsCommutable>, EVEX_V512,
                              T_MAP5, PD, EVEX_CD8<16, CD8VF>;
    defm Z128 : avx512_fp_packed<opc, OpcodeStr,
                                 !cast<Intrinsic>("int_x86_avx10_"#OpcodeStr#"bf16128"),
                                 !cast<Intrinsic>("int_x86_avx10_"#OpcodeStr#"bf16128"),
                                 v8bf16x_info, sched.PH.XMM, IsCommutable>, EVEX_V128,
                                 T_MAP5, PD, EVEX_CD8<16, CD8VF>;
    defm Z256 : avx512_fp_packed<opc, OpcodeStr,
                                 !cast<Intrinsic>("int_x86_avx10_"#OpcodeStr#"bf16256"),
                                 !cast<Intrinsic>("int_x86_avx10_"#OpcodeStr#"bf16256"),
                                 v16bf16x_info, sched.PH.YMM, IsCommutable>, EVEX_V256,
                                 T_MAP5, PD, EVEX_CD8<16, CD8VF>;
  }
}

multiclass avx10_fp_binop_bf16<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
                                X86SchedWriteSizes sched,
                                bit IsCommutable = 0,
@@ -1024,8 +1002,8 @@ defm VADDBF16 : avx10_fp_binop_bf16<0x58, "vadd", fadd, SchedWriteFAddSizes, 1>;
defm VSUBBF16 : avx10_fp_binop_bf16<0x5C, "vsub", fsub, SchedWriteFAddSizes, 0>;
defm VMULBF16 : avx10_fp_binop_bf16<0x59, "vmul", fmul, SchedWriteFMulSizes, 1>;
defm VDIVBF16 : avx10_fp_binop_bf16<0x5E, "vdiv", fdiv, SchedWriteFDivSizes, 0>;
defm VMINBF16 : avx10_fp_binop_int_bf16<0x5D, "vmin", SchedWriteFCmpSizes, 0>;
defm VMAXBF16 : avx10_fp_binop_int_bf16<0x5F, "vmax", SchedWriteFCmpSizes, 0>;
defm VMINBF16 : avx10_fp_binop_bf16<0x5D, "vmin", X86fmin, SchedWriteFCmpSizes, 0>;
defm VMAXBF16 : avx10_fp_binop_bf16<0x5F, "vmax", X86fmax, SchedWriteFCmpSizes, 0>;
}

// VCOMISBF16
+6 −0
Original line number Diff line number Diff line
@@ -693,6 +693,12 @@ static const IntrinsicData IntrinsicsWithoutChain[] = {
    X86_INTRINSIC_DATA(avx10_vdpphps_128, INTR_TYPE_3OP, X86ISD::DPFP16PS, 0),
    X86_INTRINSIC_DATA(avx10_vdpphps_256, INTR_TYPE_3OP, X86ISD::DPFP16PS, 0),
    X86_INTRINSIC_DATA(avx10_vdpphps_512, INTR_TYPE_3OP, X86ISD::DPFP16PS, 0),
    X86_INTRINSIC_DATA(avx10_vmaxbf16128, INTR_TYPE_2OP, X86ISD::FMAX, 0),
    X86_INTRINSIC_DATA(avx10_vmaxbf16256, INTR_TYPE_2OP, X86ISD::FMAX, 0),
    X86_INTRINSIC_DATA(avx10_vmaxbf16512, INTR_TYPE_2OP, X86ISD::FMAX, 0),
    X86_INTRINSIC_DATA(avx10_vminbf16128, INTR_TYPE_2OP, X86ISD::FMIN, 0),
    X86_INTRINSIC_DATA(avx10_vminbf16256, INTR_TYPE_2OP, X86ISD::FMIN, 0),
    X86_INTRINSIC_DATA(avx10_vminbf16512, INTR_TYPE_2OP, X86ISD::FMIN, 0),
    X86_INTRINSIC_DATA(avx10_vminmaxbf16128, INTR_TYPE_3OP, X86ISD::VMINMAX, 0),
    X86_INTRINSIC_DATA(avx10_vminmaxbf16256, INTR_TYPE_3OP, X86ISD::VMINMAX, 0),
    X86_INTRINSIC_DATA(avx10_vminmaxbf16512, INTR_TYPE_3OP, X86ISD::VMINMAX, 0),
+92 −0
Original line number Diff line number Diff line
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx10.1 | FileCheck %s --check-prefixes=CHECK,AVX10_1
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx10.2 | FileCheck %s --check-prefixes=CHECK,AVX10_2

define bfloat @select_ogt_bf16(bfloat %a, bfloat %b) {
; CHECK-LABEL: select_ogt_bf16:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vmovw %xmm0, %eax
; CHECK-NEXT:    vmovw %xmm1, %ecx
; CHECK-NEXT:    movl %ecx, %edx
; CHECK-NEXT:    shll $16, %edx
; CHECK-NEXT:    vmovd %edx, %xmm0
; CHECK-NEXT:    movl %eax, %edx
; CHECK-NEXT:    shll $16, %edx
; CHECK-NEXT:    vmovd %edx, %xmm1
; CHECK-NEXT:    vucomiss %xmm0, %xmm1
; CHECK-NEXT:    cmoval %eax, %ecx
; CHECK-NEXT:    vmovw %ecx, %xmm0
; CHECK-NEXT:    retq
  %cmp = fcmp ogt bfloat %a, %b
  %sel = select i1 %cmp, bfloat %a, bfloat %b
  ret bfloat %sel
}

define <8 x bfloat> @select_olt_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
; AVX10_1-LABEL: select_olt_v8bf16:
; AVX10_1:       # %bb.0:
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} ymm2 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
; AVX10_1-NEXT:    vpslld $16, %ymm2, %ymm2
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} ymm3 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
; AVX10_1-NEXT:    vpslld $16, %ymm3, %ymm3
; AVX10_1-NEXT:    vcmpltps %ymm2, %ymm3, %k1
; AVX10_1-NEXT:    vpblendmw %xmm0, %xmm1, %xmm0 {%k1}
; AVX10_1-NEXT:    vzeroupper
; AVX10_1-NEXT:    retq
;
; AVX10_2-LABEL: select_olt_v8bf16:
; AVX10_2:       # %bb.0:
; AVX10_2-NEXT:    vminbf16 %xmm1, %xmm0, %xmm0
; AVX10_2-NEXT:    retq
  %cmp = fcmp olt <8 x bfloat> %a, %b
  %sel = select <8 x i1> %cmp, <8 x bfloat> %a, <8 x bfloat> %b
  ret <8 x bfloat> %sel
}

define <16 x bfloat> @select_ogt_v16bf16(<16 x bfloat> %a, <16 x bfloat> %b) {
; AVX10_1-LABEL: select_ogt_v16bf16:
; AVX10_1:       # %bb.0:
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} zmm2 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
; AVX10_1-NEXT:    vpslld $16, %zmm2, %zmm2
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} zmm3 = ymm1[0],zero,ymm1[1],zero,ymm1[2],zero,ymm1[3],zero,ymm1[4],zero,ymm1[5],zero,ymm1[6],zero,ymm1[7],zero,ymm1[8],zero,ymm1[9],zero,ymm1[10],zero,ymm1[11],zero,ymm1[12],zero,ymm1[13],zero,ymm1[14],zero,ymm1[15],zero
; AVX10_1-NEXT:    vpslld $16, %zmm3, %zmm3
; AVX10_1-NEXT:    vcmpltps %zmm2, %zmm3, %k1
; AVX10_1-NEXT:    vpblendmw %ymm0, %ymm1, %ymm0 {%k1}
; AVX10_1-NEXT:    retq
;
; AVX10_2-LABEL: select_ogt_v16bf16:
; AVX10_2:       # %bb.0:
; AVX10_2-NEXT:    vmaxbf16 %ymm1, %ymm0, %ymm0
; AVX10_2-NEXT:    retq
  %cmp = fcmp ogt <16 x bfloat> %a, %b
  %sel = select <16 x i1> %cmp, <16 x bfloat> %a, <16 x bfloat> %b
  ret <16 x bfloat> %sel
}

define <32 x bfloat> @select_olt_v32bf16(<32 x bfloat> %a, <32 x bfloat> %b) {
; AVX10_1-LABEL: select_olt_v32bf16:
; AVX10_1:       # %bb.0:
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} zmm2 = ymm1[0],zero,ymm1[1],zero,ymm1[2],zero,ymm1[3],zero,ymm1[4],zero,ymm1[5],zero,ymm1[6],zero,ymm1[7],zero,ymm1[8],zero,ymm1[9],zero,ymm1[10],zero,ymm1[11],zero,ymm1[12],zero,ymm1[13],zero,ymm1[14],zero,ymm1[15],zero
; AVX10_1-NEXT:    vpslld $16, %zmm2, %zmm2
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} zmm3 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
; AVX10_1-NEXT:    vpslld $16, %zmm3, %zmm3
; AVX10_1-NEXT:    vcmpltps %zmm2, %zmm3, %k0
; AVX10_1-NEXT:    vextracti64x4 $1, %zmm1, %ymm2
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} zmm2 = ymm2[0],zero,ymm2[1],zero,ymm2[2],zero,ymm2[3],zero,ymm2[4],zero,ymm2[5],zero,ymm2[6],zero,ymm2[7],zero,ymm2[8],zero,ymm2[9],zero,ymm2[10],zero,ymm2[11],zero,ymm2[12],zero,ymm2[13],zero,ymm2[14],zero,ymm2[15],zero
; AVX10_1-NEXT:    vpslld $16, %zmm2, %zmm2
; AVX10_1-NEXT:    vextracti64x4 $1, %zmm0, %ymm3
; AVX10_1-NEXT:    vpmovzxwd {{.*#+}} zmm3 = ymm3[0],zero,ymm3[1],zero,ymm3[2],zero,ymm3[3],zero,ymm3[4],zero,ymm3[5],zero,ymm3[6],zero,ymm3[7],zero,ymm3[8],zero,ymm3[9],zero,ymm3[10],zero,ymm3[11],zero,ymm3[12],zero,ymm3[13],zero,ymm3[14],zero,ymm3[15],zero
; AVX10_1-NEXT:    vpslld $16, %zmm3, %zmm3
; AVX10_1-NEXT:    vcmpltps %zmm2, %zmm3, %k1
; AVX10_1-NEXT:    kunpckwd %k0, %k1, %k1
; AVX10_1-NEXT:    vpblendmw %zmm0, %zmm1, %zmm0 {%k1}
; AVX10_1-NEXT:    retq
;
; AVX10_2-LABEL: select_olt_v32bf16:
; AVX10_2:       # %bb.0:
; AVX10_2-NEXT:    vminbf16 %zmm1, %zmm0, %zmm0
; AVX10_2-NEXT:    retq
  %cmp = fcmp olt <32 x bfloat> %a, %b
  %sel = select <32 x i1> %cmp, <32 x bfloat> %a, <32 x bfloat> %b
  ret <32 x bfloat> %sel
}