Commit 2387ce01 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'fused_kernels_simd_fix' into 'main'

fused kernels simd fix for short sequences

See merge request ADLR/megatron-lm!302
parents e269e200 62424a74
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -111,7 +111,7 @@ __global__ void scaled_masked_softmax_warp_forward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches) 
@@ -230,7 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches) 
+2 −2
Original line number Diff line number Diff line
@@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1; 
@@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;