Unverified Commit 3983bf64 authored by Joseph Huber's avatar Joseph Huber Committed by GitHub
Browse files

[AMDGPU] Optimize DPP for fmin/fmax functions (#195282)

Summary:
These functions currently don't simplify in the optimistic (no-NaN) case
as their identity is not recognized by the optimizer. This PR simply
adds the -inf,+inf checks so these combine without the intermediate
moves.
parent be11e2b3
Loading
Loading
Loading
Loading
+38 −0
Original line number Diff line number Diff line
@@ -489,6 +489,44 @@ static bool isIdentityValue(unsigned OrigMIOp, MachineOperand *OldOpnd) {
    if (OldOpnd->getImm() == 1)
      return true;
    break;
  case AMDGPU::V_MIN_F32_e32:
  case AMDGPU::V_MIN_F32_e64:
    if (static_cast<uint32_t>(OldOpnd->getImm()) == /*+inf=*/0x7F800000)
      return true;
    break;
  case AMDGPU::V_MAX_F32_e32:
  case AMDGPU::V_MAX_F32_e64:
    if (static_cast<uint32_t>(OldOpnd->getImm()) == /*-inf=*/0xFF800000)
      return true;
    break;
  case AMDGPU::V_MIN_F64_e64:
  case AMDGPU::V_MIN_NUM_F64_e64:
    if (static_cast<uint64_t>(OldOpnd->getImm()) == /*+inf=*/0x7FF0000000000000)
      return true;
    break;
  case AMDGPU::V_MAX_F64_e64:
  case AMDGPU::V_MAX_NUM_F64_e64:
    if (static_cast<uint64_t>(OldOpnd->getImm()) == /*-inf=*/0xFFF0000000000000)
      return true;
    break;
  case AMDGPU::V_MIN_F16_e32:
  case AMDGPU::V_MIN_F16_e64:
  case AMDGPU::V_MIN_F16_t16_e32:
  case AMDGPU::V_MIN_F16_t16_e64:
  case AMDGPU::V_MIN_F16_fake16_e32:
  case AMDGPU::V_MIN_F16_fake16_e64:
    if (static_cast<uint16_t>(OldOpnd->getImm()) == /*+inf=*/0x7C00)
      return true;
    break;
  case AMDGPU::V_MAX_F16_e32:
  case AMDGPU::V_MAX_F16_e64:
  case AMDGPU::V_MAX_F16_t16_e32:
  case AMDGPU::V_MAX_F16_t16_e64:
  case AMDGPU::V_MAX_F16_fake16_e32:
  case AMDGPU::V_MAX_F16_fake16_e64:
    if (static_cast<uint16_t>(OldOpnd->getImm()) == /*-inf=*/0xFC00)
      return true;
    break;
  }
  return false;
}
+22 −0
Original line number Diff line number Diff line
@@ -113,6 +113,28 @@ bb1:
  br label %bb1
}

; GCN-LABEL: {{^}}dpp64_fmin:
; DPP64-GFX1251: v_min_num_f64_dpp v[0:1], v[0:1], v[0:1] [[CTL]]:1 row_mask:0xf bank_mask:0xf{{$}}
; DPP64-GFX9: v_min_f64 v[0:1], v[0:1], v[{{[0-9:]+}}]{{$}}
; DPP32: v_min{{(_num)?}}_f64{{(_e32)?}} v[0:1], v[0:1], v[{{[0-9:]+}}]{{$}}
define nofpclass(nan) double @dpp64_fmin(double nofpclass(nan) %x) {
entry:
  %dpp = tail call double @llvm.amdgcn.update.dpp.f64(double 0x7FF0000000000000, double %x, i32 337, i32 15, i32 15, i1 false)
  %min = tail call nnan double @llvm.minnum.f64(double %x, double %dpp)
  ret double %min
}

; GCN-LABEL: {{^}}dpp64_fmax:
; DPP64-GFX1251: v_max_num_f64_dpp v[0:1], v[0:1], v[0:1] [[CTL]]:1 row_mask:0xf bank_mask:0xf{{$}}
; DPP64-GFX9: v_max_f64 v[0:1], v[0:1], v[{{[0-9:]+}}]{{$}}
; DPP32: v_max{{(_num)?}}_f64{{(_e32)?}} v[0:1], v[0:1], v[{{[0-9:]+}}]{{$}}
define nofpclass(nan) double @dpp64_fmax(double nofpclass(nan) %x) #0 {
entry:
  %dpp = tail call double @llvm.amdgcn.update.dpp.f64(double 0xFFF0000000000000, double %x, i32 337, i32 15, i32 15, i1 false)
  %max = tail call nnan double @llvm.maxnum.f64(double %x, double %dpp)
  ret double %max
}

declare i32 @llvm.amdgcn.workitem.id.x()
declare i64 @llvm.amdgcn.update.dpp.i64(i64, i64, i32, i32, i32, i1) #0
declare double @llvm.ceil.f64(double)
+120 −0
Original line number Diff line number Diff line
@@ -93,6 +93,126 @@ define amdgpu_kernel void @dpp_fadd_f16(ptr addrspace(1) %arg) {
  ret void
}

; GCN-LABEL: {{^}}dpp_fmin_f32:
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
define nofpclass(nan) float @dpp_fmin_f32(float nofpclass(nan) %x) {
entry:
  %dpp.shr1 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %x, i32 273, i32 15, i32 15, i1 false)
  %min1 = tail call nnan float @llvm.minnum.f32(float %x, float %dpp.shr1)
  %dpp.shr2 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %min1, i32 274, i32 15, i32 15, i1 false)
  %min2 = tail call nnan float @llvm.minnum.f32(float %min1, float %dpp.shr2)
  %dpp.shr4 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %min2, i32 276, i32 15, i32 15, i1 false)
  %min3 = tail call nnan float @llvm.minnum.f32(float %min2, float %dpp.shr4)
  %dpp.shr8 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %min3, i32 280, i32 15, i32 15, i1 false)
  %min4 = tail call nnan float @llvm.minnum.f32(float %min3, float %dpp.shr8)
  ret float %min4
}

; GCN-LABEL: {{^}}dpp_fmax_f32:
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
define nofpclass(nan) float @dpp_fmax_f32(float nofpclass(nan) %x) #0 {
entry:
  %dpp.shr1 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %x, i32 273, i32 15, i32 15, i1 false)
  %max1 = tail call nnan float @llvm.maxnum.f32(float %x, float %dpp.shr1)
  %dpp.shr2 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %max1, i32 274, i32 15, i32 15, i1 false)
  %max2 = tail call nnan float @llvm.maxnum.f32(float %max1, float %dpp.shr2)
  %dpp.shr4 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %max2, i32 276, i32 15, i32 15, i1 false)
  %max3 = tail call nnan float @llvm.maxnum.f32(float %max2, float %dpp.shr4)
  %dpp.shr8 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %max3, i32 280, i32 15, i32 15, i1 false)
  %max4 = tail call nnan float @llvm.maxnum.f32(float %max3, float %dpp.shr8)
  ret float %max4
}

; GCN-LABEL: {{^}}dpp_fminimum_f32:
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_min{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
define nofpclass(nan) float @dpp_fminimum_f32(float nofpclass(nan) %x) {
entry:
  %dpp.shr1 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %x, i32 273, i32 15, i32 15, i1 false)
  %min1 = tail call nnan float @llvm.minimumnum.f32(float %x, float %dpp.shr1)
  %dpp.shr2 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %min1, i32 274, i32 15, i32 15, i1 false)
  %min2 = tail call nnan float @llvm.minimumnum.f32(float %min1, float %dpp.shr2)
  %dpp.shr4 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %min2, i32 276, i32 15, i32 15, i1 false)
  %min3 = tail call nnan float @llvm.minimumnum.f32(float %min2, float %dpp.shr4)
  %dpp.shr8 = tail call float @llvm.amdgcn.update.dpp.f32(float 0x7FF0000000000000, float %min3, i32 280, i32 15, i32 15, i1 false)
  %min4 = tail call nnan float @llvm.minimumnum.f32(float %min3, float %dpp.shr8)
  ret float %min4
}

; GCN-LABEL: {{^}}dpp_fmaximum_f32:
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GCN: v_max{{(_num)?}}_f32_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
define nofpclass(nan) float @dpp_fmaximum_f32(float nofpclass(nan) %x) #0 {
entry:
  %dpp.shr1 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %x, i32 273, i32 15, i32 15, i1 false)
  %max1 = tail call nnan float @llvm.maximumnum.f32(float %x, float %dpp.shr1)
  %dpp.shr2 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %max1, i32 274, i32 15, i32 15, i1 false)
  %max2 = tail call nnan float @llvm.maximumnum.f32(float %max1, float %dpp.shr2)
  %dpp.shr4 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %max2, i32 276, i32 15, i32 15, i1 false)
  %max3 = tail call nnan float @llvm.maximumnum.f32(float %max2, float %dpp.shr4)
  %dpp.shr8 = tail call float @llvm.amdgcn.update.dpp.f32(float 0xFFF0000000000000, float %max3, i32 280, i32 15, i32 15, i1 false)
  %max4 = tail call nnan float @llvm.maximumnum.f32(float %max3, float %dpp.shr8)
  ret float %max4
}

; GCN-LABEL: {{^}}dpp_fmin_f16:
; GFX9GFX10: v_min_f16_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GFX9GFX10: v_min_f16_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GFX9GFX10: v_min_f16_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GFX9GFX10: v_min_f16_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-TRUE16: v_mov_b32_dpp {{v[0-9]+}}, {{v[0-9]+}} row_shr:1 row_mask:0xf bank_mask:0xf
; GFX11-TRUE16: v_min{{(_num)?}}_f16_e32
; GFX11-FAKE16: v_min{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-FAKE16: v_min{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-FAKE16: v_min{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-FAKE16: v_min{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
define nofpclass(nan) half @dpp_fmin_f16(half nofpclass(nan) %x) {
entry:
  %dpp.shr1 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xH7C00, half %x, i32 273, i32 15, i32 15, i1 false)
  %min1 = tail call nnan half @llvm.minnum.f16(half %x, half %dpp.shr1)
  %dpp.shr2 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xH7C00, half %min1, i32 274, i32 15, i32 15, i1 false)
  %min2 = tail call nnan half @llvm.minnum.f16(half %min1, half %dpp.shr2)
  %dpp.shr4 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xH7C00, half %min2, i32 276, i32 15, i32 15, i1 false)
  %min3 = tail call nnan half @llvm.minnum.f16(half %min2, half %dpp.shr4)
  %dpp.shr8 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xH7C00, half %min3, i32 280, i32 15, i32 15, i1 false)
  %min4 = tail call nnan half @llvm.minnum.f16(half %min3, half %dpp.shr8)
  ret half %min4
}

; GCN-LABEL: {{^}}dpp_fmax_f16:
; GFX9GFX10: v_max_f16_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GFX9GFX10: v_max_f16_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GFX9GFX10: v_max_f16_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GFX9GFX10: v_max_f16_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-TRUE16: v_mov_b32_dpp {{v[0-9]+}}, {{v[0-9]+}} row_shr:1 row_mask:0xf bank_mask:0xf
; GFX11-TRUE16: v_max{{(_num)?}}_f16_e32
; GFX11-FAKE16: v_max{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:1 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-FAKE16: v_max{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:2 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-FAKE16: v_max{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:4 row_mask:0xf bank_mask:0xf{{$}}
; GFX11-FAKE16: v_max{{(_num)?}}_f16_e64_dpp v0, v0, v0 row_shr:8 row_mask:0xf bank_mask:0xf{{$}}
define nofpclass(nan) half @dpp_fmax_f16(half nofpclass(nan) %x) #0 {
entry:
  %dpp.shr1 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xHFC00, half %x, i32 273, i32 15, i32 15, i1 false)
  %max1 = tail call nnan half @llvm.maxnum.f16(half %x, half %dpp.shr1)
  %dpp.shr2 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xHFC00, half %max1, i32 274, i32 15, i32 15, i1 false)
  %max2 = tail call nnan half @llvm.maxnum.f16(half %max1, half %dpp.shr2)
  %dpp.shr4 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xHFC00, half %max2, i32 276, i32 15, i32 15, i1 false)
  %max3 = tail call nnan half @llvm.maxnum.f16(half %max2, half %dpp.shr4)
  %dpp.shr8 = tail call half @llvm.amdgcn.update.dpp.f16(half 0xHFC00, half %max3, i32 280, i32 15, i32 15, i1 false)
  %max4 = tail call nnan half @llvm.maxnum.f16(half %max3, half %dpp.shr8)
  ret half %max4
}

declare i32 @llvm.amdgcn.workitem.id.x()
declare i32 @llvm.amdgcn.update.dpp.i32(i32, i32, i32, i32, i32, i1) #0
declare float @llvm.ceil.f32(float)