Loading megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +2 −2 Original line number Diff line number Diff line Loading @@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; Loading Loading @@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; Loading Loading
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +2 −2 Original line number Diff line number Diff line Loading @@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; Loading Loading @@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; Loading