Loading megatron/fused_kernels/scaled_masked_softmax.h +2 −2 Original line number Diff line number Diff line Loading @@ -111,7 +111,7 @@ __global__ void scaled_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading Loading @@ -230,7 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading Loading
megatron/fused_kernels/scaled_masked_softmax.h +2 −2 Original line number Diff line number Diff line Loading @@ -111,7 +111,7 @@ __global__ void scaled_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading Loading @@ -230,7 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading