Unverified Commit 0a600c34 authored by Guray Ozen's avatar Guray Ozen Committed by GitHub
Browse files

[mlir][nvgpu] Make `phaseParity` of `mbarrier.try_wait` `i1` (#81460)

Currently, `phaseParity` argument of `nvgpu.mbarrier.try_wait.parity` is
index. This can cause a problem if it's passed any value different than
0 or 1. Because the PTX instruction only accepts even or odd phase. This
PR makes phaseParity argument i1 to avoid misuse.

Here is the information from PTX doc:

```
The .parity variant of the instructions test for the completion of the phase indicated 
by the operand phaseParity, which is the integer parity of either the current phase or 
the immediately preceding phase of the mbarrier object. An even phase has integer 
parity 0 and an odd phase has integer parity of 1. So the valid values of phaseParity 
operand are 0 and 1.
```
See for more information:

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
parent 05ad0d46
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -609,14 +609,16 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
    phase. Suspended thread resumes execution when the specified phase completes 
    OR before the phase completes following a system-dependent time limit. 

    The `$phaseParity` specifies either even phase (0) or odd phase (1) to 
    wait.

    Example:
    ```mlir
      nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
      nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
    ```

  }];
  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$phase, Index:$ticks, Index:$mbarId);
  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phase `,` $ticks attr-dict `:` type($barriers)";  
  let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";  
}

def NVGPU_TmaPrefetchOp : NVGPU_Op<"tma.prefetch.descriptor", []> {
+2 −1
Original line number Diff line number Diff line
@@ -956,7 +956,8 @@ struct NVGPUMBarrierTryWaitParityLowering
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase = truncToI32(b, adaptor.getPhase());
    Value phase =
        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
+2 −1
Original line number Diff line number Diff line
@@ -1010,7 +1010,8 @@ void HopperBuilder::buildBarrierArriveTx(

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Type i1 = rewriter.getI1Type();
  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
+4 −4
Original line number Diff line number Diff line
@@ -590,12 +590,12 @@ func.func @mbarrier_txcount() {
    }
      

    %phase = arith.constant 0 : index
    %phase_c0 = arith.constant 0 : i1
    %ticks = arith.constant 10000000 : index
    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
    // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType

    func.return 
}
@@ -626,12 +626,12 @@ func.func @mbarrier_txcount_pred() {
    // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
    nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType

    %phase = arith.constant 0 : index
    %phase_c0 = arith.constant 0 : i1
    %ticks = arith.constant 10000000 : index
    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
    // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType

    func.return 
}
+1 −1
Original line number Diff line number Diff line
@@ -62,7 +62,7 @@ func.func @main() {
    //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
    //      CHECK: }
    //
    //      CHECK: %[[c0_6:.*]] = arith.constant 0 : index
    //      CHECK: %[[c0_6:.*]] = llvm.mlir.constant(false) : i1 
    //      CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index
    //      CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>

Loading