Commit 82945c1a authored by mohammad's avatar mohammad
Browse files

Fixed assert issue in clip_grad_norm_fp32 for params whose grad is None

parent 571f10a0
Loading
Loading
Loading
Loading
+2 −2
Original line number | Diff line number | Diff line
@@ -53,14 +53,14 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    grads = []
    grads_for_norm = []
    for param in parameters:
        # Make sure the grads are in fp32
        assert param.grad.type() == 'torch.cuda.FloatTensor'
        grad_not_none = param.grad is not None
        is_not_shared = not hasattr(param, 'shared') or not param.shared
        is_not_tp_duplicate = param.tensor_model_parallel or \
                              (mpu.get_tensor_model_parallel_rank() == 0)
        grad = param.grad.detach()
        if grad_not_none:
            # Make sure the grads are in fp32
            assert param.grad.type() == 'torch.cuda.FloatTensor'
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)