Commit cd499559 authored by Jared Casper's avatar Jared Casper
Browse files

Revert incorrect fix.

parent 2a34e0ec
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -84,7 +84,7 @@ def average_losses_across_data_parallel_group(losses):
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses /
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

    return averaged_losses