Commit 57437cb1 authored by Rewon Child's avatar Rewon Child
Browse files

Fix syntax

parent 0aff3629
Loading
Loading
Loading
Loading
+3 −8
Original line number Diff line number Diff line
@@ -129,18 +129,13 @@ def count_zeros_fp32(parameters):
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads_to_count = []
    total_num_zeros = 0.0
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            grads_to_count.append(grad)

    total_num_zeros = 0.0

    for grad in grads_to_count:
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros