Unverified commit 4263b2ec authored by peterbell10, committed by GitHub
Browse files

[NVPTX] Expand EXTLOAD for v8f16 and v8bf16 (#72672)

In openai/triton#2483 I've encountered a bug in the NVPTX codegen. Given
`load<8 x half>` followed by `fpext to <8 x float>` we get

```
ld.shared.v4.b16 	{%f1, %f2, %f3, %f4}, [%r15+8];
ld.shared.v4.b16 	{%f5, %f6, %f7, %f8}, [%r15];
```

This loads float16 values into float registers without any conversion,
and the result is simply garbage.

This PR brings `v8f16` and `v8bf16` into line with the other vector
types by expanding them to load + cvt.

cc @manman-ren @Artem-B @jlebar
parent bfbfd1ca
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -606,6 +606,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
  // Turn FP truncstore into trunc + store.
  // FIXME: vector types should also be expanded
  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
+20 −0
Original line number Diff line number Diff line
@@ -207,3 +207,23 @@ define bfloat @test_select_cc_bf16_f64(double %a, double %b, bfloat %c, bfloat %
  %r = select i1 %cc, bfloat %c, bfloat %d
  ret bfloat %r
}

; CHECK-LABEL: test_extload_bf16x8
; CHECK: ld.shared.v4.b32 {%r
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
; SM80: cvt.f32.bf16 %f{{.*}}, %rs
define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
  %load = load <8 x bfloat>, ptr addrspace(3) %arg, align 16
  %res = fpext <8 x bfloat> %load to <8 x float>
  ret <8 x float> %res
}
+58 −12
Original line number Diff line number Diff line
@@ -99,9 +99,20 @@ define void @foo_complex(ptr nocapture readonly align 16 dereferenceable(1342177

; CHECK-LABEL: extv8f16_global_a16(
define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
; CHECK: ld.global.v4.b16 {%f
; CHECK: ld.global.v4.b16 {%f
; CHECK: ld.global.v4.b32 {%r
  %v = load <8 x half>, ptr addrspace(1) %src, align 16
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
  %ext = fpext <8 x half> %v to <8 x float>
; CHECK: st.global.v4.f32
; CHECK: st.global.v4.f32
@@ -111,11 +122,23 @@ define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst

; CHECK-LABEL: extv8f16_global_a4(
define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
; CHECK: ld.global.v2.b16 {%f
; CHECK: ld.global.v2.b16 {%f
; CHECK: ld.global.v2.b16 {%f
; CHECK: ld.global.v2.b16 {%f
; CHECK: ld.global.b32 %r
; CHECK: ld.global.b32 %r
; CHECK: ld.global.b32 %r
; CHECK: ld.global.b32 %r
  %v = load <8 x half>, ptr addrspace(1) %src, align 4
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
  %ext = fpext <8 x half> %v to <8 x float>
; CHECK: st.global.v4.f32
; CHECK: st.global.v4.f32
@@ -126,9 +149,20 @@ define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst,

; CHECK-LABEL: extv8f16_generic_a16(
define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
; CHECK: ld.v4.b16 {%f
; CHECK: ld.v4.b16 {%f
; CHECK: ld.v4.b32 {%r
  %v = load <8 x half>, ptr %src, align 16
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
  %ext = fpext <8 x half> %v to <8 x float>
; CHECK: st.v4.f32
; CHECK: st.v4.f32
@@ -138,11 +172,23 @@ define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalia

; CHECK-LABEL: extv8f16_generic_a4(
define void @extv8f16_generic_a4(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
; CHECK: ld.v2.b16 {%f
; CHECK: ld.v2.b16 {%f
; CHECK: ld.v2.b16 {%f
; CHECK: ld.v2.b16 {%f
; CHECK: ld.b32 %r
; CHECK: ld.b32 %r
; CHECK: ld.b32 %r
; CHECK: ld.b32 %r
  %v = load <8 x half>, ptr %src, align 4
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: mov.b32 {%rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
; CHECK: cvt.f32.f16 %f{{.*}}, %rs
  %ext = fpext <8 x half> %v to <8 x float>
; CHECK: st.v4.f32
; CHECK: st.v4.f32