Commit a6d9c944 authored by Hassnaa Hamdi's avatar Hassnaa Hamdi
Browse files

[AArch64 - SVE]: Use SVE to lower reduce.fadd.

Differential Revision: https://reviews.llvm.org/D132573

skip custom-lowering for v1f64 to be expanded instead, because it has only one lane

Differential Revision: https://reviews.llvm.org/D132959
parent 1ed555a6
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -146,7 +146,7 @@ namespace llvm {
      v64f16         =  86,   //   64 x f16
      v128f16        =  87,   //  128 x f16
      v256f16        =  88,   //  256 x f16
      v512f16        =  89,   //  256 x f16
      v512f16        =  89,   //  512 x f16

      v2bf16         =  90,   //    2 x bf16
      v3bf16         =  91,   //    3 x bf16
+4 −4
Original line number Diff line number Diff line
@@ -1380,6 +1380,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
    setOperationAction(ISD::MUL, MVT::v1i64, Custom);
    setOperationAction(ISD::MUL, MVT::v2i64, Custom);
    // NEON doesn't support across-vector reductions, but SVE does.
    for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v2f64})
      setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
    // NOTE: Currently this has to happen after computeRegisterProperties rather
    // than the preferred option of combining it with the addRegisterClass call.
    if (Subtarget->useSVEForFixedLengthVectors()) {
@@ -1433,10 +1437,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
        setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
      }
      // FP operations with no NEON support.
      for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32,
                      MVT::v1f64, MVT::v2f64})
        setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
      // Use SVE for vectors with more than 2 elements.
      for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
+25 −37
Original line number Diff line number Diff line
@@ -13,14 +13,11 @@ target triple = "aarch64-unknown-linux-gnu"
define half @fadda_v4f16(half %start, <4 x half> %a) vscale_range(1,0) #0 {
; CHECK-LABEL: fadda_v4f16:
; CHECK:       // %bb.0:
; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
; CHECK-NEXT:    mov h2, v1.h[1]
; CHECK-NEXT:    fadd h0, h0, h1
; CHECK-NEXT:    mov h3, v1.h[2]
; CHECK-NEXT:    mov h1, v1.h[3]
; CHECK-NEXT:    fadd h0, h0, h2
; CHECK-NEXT:    fadd h0, h0, h3
; CHECK-NEXT:    fadd h0, h0, h1
; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT:    ptrue p0.h, vl4
; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
; CHECK-NEXT:    fadda h0, p0, h0, z1.h
; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT:    ret
  %res = call half @llvm.vector.reduce.fadd.v4f16(half %start, <4 x half> %a)
  ret half %res
@@ -30,21 +27,11 @@ define half @fadda_v4f16(half %start, <4 x half> %a) vscale_range(1,0) #0 {
define half @fadda_v8f16(half %start, <8 x half> %a) vscale_range(1,0) #0 {
; CHECK-LABEL: fadda_v8f16:
; CHECK:       // %bb.0:
; CHECK-NEXT:    mov h2, v1.h[1]
; CHECK-NEXT:    fadd h0, h0, h1
; CHECK-NEXT:    mov h3, v1.h[2]
; CHECK-NEXT:    fadd h0, h0, h2
; CHECK-NEXT:    mov h2, v1.h[3]
; CHECK-NEXT:    fadd h0, h0, h3
; CHECK-NEXT:    mov h3, v1.h[4]
; CHECK-NEXT:    fadd h0, h0, h2
; CHECK-NEXT:    mov h2, v1.h[5]
; CHECK-NEXT:    fadd h0, h0, h3
; CHECK-NEXT:    mov h3, v1.h[6]
; CHECK-NEXT:    mov h1, v1.h[7]
; CHECK-NEXT:    fadd h0, h0, h2
; CHECK-NEXT:    fadd h0, h0, h3
; CHECK-NEXT:    fadd h0, h0, h1
; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT:    ptrue p0.h, vl8
; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT:    fadda h0, p0, h0, z1.h
; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT:    ret
  %res = call half @llvm.vector.reduce.fadd.v8f16(half %start, <8 x half> %a)
  ret half %res
@@ -122,10 +109,11 @@ define half @fadda_v128f16(half %start, <128 x half>* %a) vscale_range(16,0) #0
define float @fadda_v2f32(float %start, <2 x float> %a) vscale_range(1,0) #0 {
; CHECK-LABEL: fadda_v2f32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
; CHECK-NEXT:    mov s2, v1.s[1]
; CHECK-NEXT:    fadd s0, s0, s1
; CHECK-NEXT:    fadd s0, s0, s2
; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
; CHECK-NEXT:    ptrue p0.s, vl2
; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
; CHECK-NEXT:    fadda s0, p0, s0, z1.s
; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
; CHECK-NEXT:    ret
  %res = call float @llvm.vector.reduce.fadd.v2f32(float %start, <2 x float> %a)
  ret float %res
@@ -135,13 +123,11 @@ define float @fadda_v2f32(float %start, <2 x float> %a) vscale_range(1,0) #0 {
define float @fadda_v4f32(float %start, <4 x float> %a) vscale_range(1,0) #0 {
; CHECK-LABEL: fadda_v4f32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    mov s2, v1.s[1]
; CHECK-NEXT:    fadd s0, s0, s1
; CHECK-NEXT:    mov s3, v1.s[2]
; CHECK-NEXT:    mov s1, v1.s[3]
; CHECK-NEXT:    fadd s0, s0, s2
; CHECK-NEXT:    fadd s0, s0, s3
; CHECK-NEXT:    fadd s0, s0, s1
; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
; CHECK-NEXT:    ptrue p0.s, vl4
; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT:    fadda s0, p0, s0, z1.s
; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
; CHECK-NEXT:    ret
  %res = call float @llvm.vector.reduce.fadd.v4f32(float %start, <4 x float> %a)
  ret float %res
@@ -229,9 +215,11 @@ define double @fadda_v1f64(double %start, <1 x double> %a) vscale_range(1,0) #0
define double @fadda_v2f64(double %start, <2 x double> %a) vscale_range(1,0) #0 {
; CHECK-LABEL: fadda_v2f64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    mov d2, v1.d[1]
; CHECK-NEXT:    fadd d0, d0, d1
; CHECK-NEXT:    fadd d0, d0, d2
; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
; CHECK-NEXT:    ptrue p0.d, vl2
; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT:    fadda d0, p0, d0, z1.d
; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT:    ret
  %res = call double @llvm.vector.reduce.fadd.v2f64(double %start, <2 x double> %a)
  ret double %res