Unverified Commit 2c39855f authored by David Sherwood's avatar David Sherwood Committed by GitHub
Browse files

[AArch64] Sanitise pow inputs using a target DAG combine (#192958)

Sometimes we see LLVM IR like this:

  %pow = call fast <4 x float> @llvm.pow.v4f32(...)
  %fcmp = fcmp fast ...
  %res = select <4 x i1> %fcmp, <4 x float> %val, <4 x float> %pow

where the pow intrinsic is called unconditionally, but only certain
lanes of the result are used. In fact, LLVM actively encourages code
like this due to the intrinsic being marked as safe to speculatively
execute. However, we know when using certain vector libraries like
ArmPL that this can be very costly if the unused lanes would take
the pow call down an expensive path. For example, if an input to
pow is a special value (inf, NaN, -0) then it triggers slow special
case handling, and ultimately the result is going to be ignored
anyway. For this reason we prefer to sanitise the pow input to
use 'safe' values when we know the result is going to be discarded.
The above example LLVM IR would then look like

  %fcmp = fcmp fast ...
  %sel = select <4 x i1> %fcmp, <4 x float> splat (float 1.0), <4 x float> %x
  %pow = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %sel, ...)
  %res = select <4 x i1> %fcmp, <4 x float> %val, <4 x float> %pow

where the value 1.0 is chosen because pow(1.0, y) is known to return
1.0 for every exponent y, making it a safe input for the dead lanes.
parent dc19e4b0
Loading
Loading
Loading
Loading
+55 −0
Original line number Diff line number Diff line
@@ -27773,6 +27773,58 @@ static SDValue trySwapVSelectOperands(SDNode *N, SelectionDAG &DAG) {
                     {InverseSetCC, SelectB, SelectA});
}
// Sanitise the first input of an FPOW that feeds exactly one arm of a
// VSELECT: lanes whose pow result is discarded by the select have their
// base replaced with 1.0 (pow(1.0, y) == 1.0 for all y). This keeps
// expensive special-case handling of inf/NaN/-0 in vector math libraries
// (e.g. ArmPL) off lanes whose result is thrown away anyway.
static SDValue performVselectPowCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI) {
  assert(N->getOpcode() == ISD::VSELECT && "Expected VSELECT opcode");
  SDValue Mask = N->getOperand(0);
  SDValue TVal = N->getOperand(1);
  SDValue FVal = N->getOperand(2);
  bool PowOnTrue = TVal.getOpcode() == ISD::FPOW;
  bool PowOnFalse = FVal.getOpcode() == ISD::FPOW;
  // Exactly one arm must be a pow. If both are, we could equally remove the
  // select and simply select between the pow inputs instead, so don't
  // bother here.
  if (PowOnTrue == PowOnFalse)
    return SDValue();

  // Other users of the pow need all lanes of the result, so its input
  // cannot be overwritten.
  if ((PowOnTrue && !TVal.hasOneUse()) || (PowOnFalse && !FVal.hasOneUse()))
    return SDValue();

  // Only worthwhile when the pow lowers to a library call.
  EVT VT = N->getValueType(0);
  RTLIB::Libcall LC = RTLIB::getPOW(VT);
  SelectionDAG &DAG = DCI.DAG;
  const auto &TLI = DAG.getTargetLoweringInfo();
  if (TLI.getLibcallLoweringInfo().getLibcallImpl(LC) == RTLIB::Unsupported)
    return SDValue();

  SDValue Pow = PowOnTrue ? TVal : FVal;
  SDValue Base = Pow->getOperand(0);
  // Bail out if the base is already a select, in order to avoid an infinite
  // combine loop.
  if (Base.getOpcode() == ISD::VSELECT)
    return SDValue();

  // For a given call pow(x, y) when x=1.0 it is guaranteed to return 1.0
  // for any value of y, so 1.0 is a safe base for the discarded lanes.
  SDLoc DL(N);
  SDValue One = DAG.getConstantFP(1.0, DL, VT);
  SDValue SafeBase =
      PowOnTrue ? DAG.getNode(ISD::VSELECT, DL, VT, Mask, Base, One)
                : DAG.getNode(ISD::VSELECT, DL, VT, Mask, One, Base);
  SDValue NewPow = DAG.getNode(ISD::FPOW, DL, VT, SafeBase,
                               Pow->getOperand(1), Pow->getFlags());
  return PowOnTrue ? DAG.getNode(ISD::VSELECT, DL, VT, Mask, NewPow, FVal)
                   : DAG.getNode(ISD::VSELECT, DL, VT, Mask, TVal, NewPow);
}
// vselect (v1i1 setcc) ->
//     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
// FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
@@ -27860,6 +27912,9 @@ static SDValue performVSelectCombine(SDNode *N,
    }
  }
  if (SDValue R = performVselectPowCombine(N, DCI))
    return R;
  EVT CmpVT = N0.getOperand(0).getValueType();
  if (N0.getOpcode() != ISD::SETCC ||
      CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
+228 −0
Original line number Diff line number Diff line
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc -mattr=+neon,+sve -mtriple=aarch64-linux-gnu --vector-library=ArmPL < %s | FileCheck %s

define <4 x float> @select_false_is_pow_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c, <4 x float> %d) nounwind {
; CHECK-LABEL: select_false_is_pow_v4f32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT:    mov v16.16b, v2.16b
; CHECK-NEXT:    fmov v2.4s, #5.00000000
; CHECK-NEXT:    fcmge v17.4s, v2.4s, v16.4s
; CHECK-NEXT:    fmov v2.4s, #1.00000000
; CHECK-NEXT:    bit v0.16b, v2.16b, v17.16b
; CHECK-NEXT:    bl armpl_vpowq_f32
; CHECK-NEXT:    bit v0.16b, v16.16b, v17.16b
; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT:    ret
  ; pow feeds the false arm of the select, so the combine fires: lanes where
  ; the compare is true get base 1.0 (fmov/bit) before the armpl_vpowq_f32 call.
  %pow = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %a, <4 x float> %b)
  %fcmp = fcmp fast ole <4 x float> %c, splat (float 5.0e+0)
  %res = select <4 x i1> %fcmp, <4 x float> %c, <4 x float> %pow
  ret <4 x float> %res
}

define <2 x double> @select_false_is_pow_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c, <2 x double> %d) nounwind {
; CHECK-LABEL: select_false_is_pow_v2f64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT:    mov v16.16b, v3.16b
; CHECK-NEXT:    fmov v3.2d, #5.00000000
; CHECK-NEXT:    fcmge v17.2d, v3.2d, v2.2d
; CHECK-NEXT:    fmov v2.2d, #1.00000000
; CHECK-NEXT:    bit v0.16b, v2.16b, v17.16b
; CHECK-NEXT:    bl armpl_vpowq_f64
; CHECK-NEXT:    bit v0.16b, v16.16b, v17.16b
; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT:    ret
  ; Same as the v4f32 case but for double: the pow base is masked to 1.0 on
  ; the true lanes of the compare before calling armpl_vpowq_f64.
  %pow = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %a, <2 x double> %b)
  %fcmp = fcmp fast ole <2 x double> %c, splat (double 5.0e+0)
  %res = select <2 x i1> %fcmp, <2 x double> %d, <2 x double> %pow
  ret <2 x double> %res
}

; TODO: For scalable vectors we should really just be able to pass in the
; mask when lowering FPOW to a call instruction.
define <vscale x 4 x float> @select_false_is_pow_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, <vscale x 4 x float> %d) nounwind {
; CHECK-LABEL: select_false_is_pow_nxv4f32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
; CHECK-NEXT:    addvl sp, sp, #-2
; CHECK-NEXT:    fmov z3.s, #5.00000000
; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Spill
; CHECK-NEXT:    ptrue p0.s
; CHECK-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
; CHECK-NEXT:    mov z8.d, z2.d
; CHECK-NEXT:    fcmge p4.s, p0/z, z3.s, z2.s
; CHECK-NEXT:    fmov z0.s, p4/m, #1.00000000
; CHECK-NEXT:    bl armpl_svpow_f32_x
; CHECK-NEXT:    mov z0.s, p4/m, z8.s
; CHECK-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Reload
; CHECK-NEXT:    addvl sp, sp, #2
; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
; CHECK-NEXT:    ret
  ; Scalable-vector version: the true lanes of z0 are set to 1.0 under
  ; predicate p4 (fmov z0.s, p4/m, #1.0) before the armpl_svpow_f32_x call.
  %pow = call fast <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b)
  %fcmp = fcmp fast ole <vscale x 4 x float> %c, splat (float 5.0e+0)
  %res = select <vscale x 4 x i1> %fcmp, <vscale x 4 x float> %c, <vscale x 4 x float> %pow
  ret <vscale x 4 x float> %res
}

define <vscale x 2 x double> @select_false_is_pow_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c, <vscale x 2 x double> %d) nounwind {
; CHECK-LABEL: select_false_is_pow_nxv2f64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
; CHECK-NEXT:    addvl sp, sp, #-2
; CHECK-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
; CHECK-NEXT:    mov z8.d, z3.d
; CHECK-NEXT:    fmov z3.d, #5.00000000
; CHECK-NEXT:    ptrue p0.d
; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Spill
; CHECK-NEXT:    fcmge p4.d, p0/z, z3.d, z2.d
; CHECK-NEXT:    fmov z0.d, p4/m, #1.00000000
; CHECK-NEXT:    bl armpl_svpow_f64_x
; CHECK-NEXT:    mov z0.d, p4/m, z8.d
; CHECK-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Reload
; CHECK-NEXT:    addvl sp, sp, #2
; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
; CHECK-NEXT:    ret
  ; Scalable double version: base lanes selected by p4 become 1.0 before
  ; the armpl_svpow_f64_x call.
  %pow = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b)
  %fcmp = fcmp fast ole <vscale x 2 x double> %c, splat (double 5.0e+0)
  %res = select <vscale x 2 x i1> %fcmp, <vscale x 2 x double> %d, <vscale x 2 x double> %pow
  ret <vscale x 2 x double> %res
}

; Only added one variant for the true value case, since the other VFs are tested above.
define <4 x float> @select_true_is_pow_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c, <4 x float> %d) nounwind {
; CHECK-LABEL: select_true_is_pow_v4f32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT:    mov v16.16b, v2.16b
; CHECK-NEXT:    fmov v2.4s, #5.00000000
; CHECK-NEXT:    fcmge v17.4s, v2.4s, v16.4s
; CHECK-NEXT:    fmov v2.4s, #1.00000000
; CHECK-NEXT:    bif v0.16b, v2.16b, v17.16b
; CHECK-NEXT:    bl armpl_vpowq_f32
; CHECK-NEXT:    bif v0.16b, v16.16b, v17.16b
; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT:    ret
  ; pow feeds the true arm of the select, so the polarity flips: the false
  ; lanes of the compare get base 1.0 (bif rather than bit) before the call.
  %pow = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %a, <4 x float> %b)
  %fcmp = fcmp fast ole <4 x float> %c, splat (float 5.0e+0)
  %res = select <4 x i1> %fcmp, <4 x float> %pow, <4 x float> %c
  ret <4 x float> %res
}

; Negative tests

; We don't have a vector library pow function for half types
define <8 x half> @select_true_is_pow_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, <8 x half> %d) nounwind {
; CHECK-LABEL: select_true_is_pow_v8f16:
; CHECK:       // %bb.0:
; CHECK-NEXT:    sub sp, sp, #80
; CHECK-NEXT:    stp q2, q0, [sp] // 32-byte Folded Spill
; CHECK-NEXT:    mov h3, v0.h[1]
; CHECK-NEXT:    mov h2, v1.h[1]
; CHECK-NEXT:    str q1, [sp, #32] // 16-byte Spill
; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Spill
; CHECK-NEXT:    fcvt s0, h3
; CHECK-NEXT:    fcvt s1, h2
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    str q0, [sp, #48] // 16-byte Spill
; CHECK-NEXT:    ldp q0, q1, [sp, #16] // 32-byte Folded Reload
; CHECK-NEXT:    fcvt s0, h0
; CHECK-NEXT:    fcvt s1, h1
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    ldp q1, q3, [sp, #32] // 32-byte Folded Reload
; CHECK-NEXT:    mov h1, v1.h[2]
; CHECK-NEXT:    mov v0.h[1], v3.h[0]
; CHECK-NEXT:    fcvt s1, h1
; CHECK-NEXT:    str q0, [sp, #48] // 16-byte Spill
; CHECK-NEXT:    ldr q0, [sp, #16] // 16-byte Reload
; CHECK-NEXT:    mov h0, v0.h[2]
; CHECK-NEXT:    fcvt s0, h0
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    ldr q1, [sp, #48] // 16-byte Reload
; CHECK-NEXT:    mov v1.h[2], v0.h[0]
; CHECK-NEXT:    str q1, [sp, #48] // 16-byte Spill
; CHECK-NEXT:    ldp q0, q1, [sp, #16] // 32-byte Folded Reload
; CHECK-NEXT:    mov h0, v0.h[3]
; CHECK-NEXT:    mov h1, v1.h[3]
; CHECK-NEXT:    fcvt s0, h0
; CHECK-NEXT:    fcvt s1, h1
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    ldr q1, [sp, #48] // 16-byte Reload
; CHECK-NEXT:    mov v1.h[3], v0.h[0]
; CHECK-NEXT:    str q1, [sp, #48] // 16-byte Spill
; CHECK-NEXT:    ldp q0, q1, [sp, #16] // 32-byte Folded Reload
; CHECK-NEXT:    mov h0, v0.h[4]
; CHECK-NEXT:    mov h1, v1.h[4]
; CHECK-NEXT:    fcvt s0, h0
; CHECK-NEXT:    fcvt s1, h1
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    ldr q1, [sp, #48] // 16-byte Reload
; CHECK-NEXT:    mov v1.h[4], v0.h[0]
; CHECK-NEXT:    str q1, [sp, #48] // 16-byte Spill
; CHECK-NEXT:    ldp q0, q1, [sp, #16] // 32-byte Folded Reload
; CHECK-NEXT:    mov h0, v0.h[5]
; CHECK-NEXT:    mov h1, v1.h[5]
; CHECK-NEXT:    fcvt s0, h0
; CHECK-NEXT:    fcvt s1, h1
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    ldr q1, [sp, #48] // 16-byte Reload
; CHECK-NEXT:    mov v1.h[5], v0.h[0]
; CHECK-NEXT:    str q1, [sp, #48] // 16-byte Spill
; CHECK-NEXT:    ldp q0, q1, [sp, #16] // 32-byte Folded Reload
; CHECK-NEXT:    mov h0, v0.h[6]
; CHECK-NEXT:    mov h1, v1.h[6]
; CHECK-NEXT:    fcvt s0, h0
; CHECK-NEXT:    fcvt s1, h1
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    ldr q1, [sp, #48] // 16-byte Reload
; CHECK-NEXT:    mov v1.h[6], v0.h[0]
; CHECK-NEXT:    str q1, [sp, #48] // 16-byte Spill
; CHECK-NEXT:    ldp q0, q1, [sp, #16] // 32-byte Folded Reload
; CHECK-NEXT:    mov h0, v0.h[7]
; CHECK-NEXT:    mov h1, v1.h[7]
; CHECK-NEXT:    fcvt s0, h0
; CHECK-NEXT:    fcvt s1, h1
; CHECK-NEXT:    bl powf
; CHECK-NEXT:    ldr q2, [sp] // 16-byte Reload
; CHECK-NEXT:    fcvt h0, s0
; CHECK-NEXT:    ldr q3, [sp, #48] // 16-byte Reload
; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Reload
; CHECK-NEXT:    fcmle v1.8h, v2.8h, #0.0
; CHECK-NEXT:    mov v3.h[7], v0.h[0]
; CHECK-NEXT:    mov v0.16b, v1.16b
; CHECK-NEXT:    bsl v0.16b, v3.16b, v2.16b
; CHECK-NEXT:    add sp, sp, #80
; CHECK-NEXT:    ret
  ; Negative test: no ArmPL pow routine exists for half, so the pow is
  ; scalarised to powf calls and the combine does not fire - note there is
  ; no 1.0 masking of the input anywhere above.
  %pow = call fast <8 x half> @llvm.pow.v8f16(<8 x half> %a, <8 x half> %b)
  %fcmp = fcmp fast ole <8 x half> %c, zeroinitializer
  %res = select <8 x i1> %fcmp, <8 x half> %pow, <8 x half> %c
  ret <8 x half> %res
}

define <4 x float> @select_pow_mul_use_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c, <4 x float> %d) nounwind {
; CHECK-LABEL: select_pow_mul_use_v4f32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT:    mov v16.16b, v2.16b
; CHECK-NEXT:    bl armpl_vpowq_f32
; CHECK-NEXT:    fmov v1.4s, #5.00000000
; CHECK-NEXT:    fcmge v1.4s, v1.4s, v16.4s
; CHECK-NEXT:    bsl v1.16b, v16.16b, v0.16b
; CHECK-NEXT:    fadd v0.4s, v1.4s, v0.4s
; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT:    ret
  ; Negative test: %pow has a second use (the fadd), so all of its lanes are
  ; live and the combine must not fire - the armpl call runs on the raw input.
  %pow = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %a, <4 x float> %b)
  %fcmp = fcmp fast ole <4 x float> %c, splat (float 5.0e+0)
  %sel = select <4 x i1> %fcmp, <4 x float> %c, <4 x float> %pow
  %res = fadd <4 x float> %sel, %pow
  ret <4 x float> %res
}