Commit 531152d9 authored by Vijay Korthikanti
Browse files

minor fixes

parent b1a83375
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -32,11 +32,9 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
// copy_vector specialization for a single __half element:
// a plain scalar assignment of *src into *dst.
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }

// copy_vector specialization for a single float element:
// a plain scalar assignment of *src into *dst.
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }

// copy_vector specialization for four consecutive __half elements.
// Both pointers are reinterpreted as float2* so the four 2-byte values move
// as one 8-byte load/store instead of four scalar copies.
// NOTE(review): dereferencing as float2 presumably requires dst and src to be
// 8-byte aligned -- confirm at the call sites.
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }

// copy_vector specialization for a single uint8_t element:
// a plain scalar assignment of *src into *dst.
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

@@ -250,6 +248,8 @@ __global__ void scaled_masked_softmax_warp_backward(
    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+8 −7
Original line number Diff line number Diff line
@@ -32,21 +32,22 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
// copy_vector specialization for a single __half element:
// a plain scalar assignment of *src into *dst.
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }

// copy_vector specialization for a single float element:
// a plain scalar assignment of *src into *dst.
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }

// copy_vector specialization for four consecutive __half elements.
// Both pointers are reinterpreted as float2* so the four 2-byte values move
// as one 8-byte load/store instead of four scalar copies.
// NOTE(review): dereferencing as float2 presumably requires dst and src to be
// 8-byte aligned -- confirm at the call sites.
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }

// copy_zero_vector specialization for four consecutive __half elements:
// intended to clear 8 bytes at dst through a float2-typed store.
// NOTE(review): this form assigns the integer literal 0 to a float2 lvalue,
// which likely does not compile because float2 is a struct; it also appears
// before the primary copy_zero_vector template declaration visible further
// down in this diff. A later line of the same diff writes
// make_float2(0.0f, 0.0f) instead.
template <>
__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = 0; }

// copy_vector specialization for a single uint8_t element:
// a plain scalar assignment of *src into *dst.
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

// copy_vector specialization for four consecutive uint8_t elements.
// Both pointers are reinterpreted as half2*; half2 serves only as a 4-byte
// carrier here so the four bytes move as one load/store, and no half
// arithmetic is performed on the data.
// NOTE(review): presumably requires dst and src to be 4-byte aligned --
// confirm at the call sites.
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }

// Primary template declaration for copy_zero_vector. No generic definition is
// given here, so only explicitly specialized <Datatype, ELEMENTS_PER_LDG>
// combinations can be used.
//   dst: destination pointer that the specialization writes zeros to.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

// copy_zero_vector specialization for four consecutive __half elements.
// dst is reinterpreted as float2* and written with make_float2(0.0f, 0.0f),
// clearing all 8 bytes (four 2-byte halves) in a single store.
// NOTE(review): presumably requires dst to be 8-byte aligned -- confirm at
// the call sites.
template <>
__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }


int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
@@ -199,7 +200,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
            } else if (element_index < element_count) {
                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE)
                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
            } else {
                break;
            }