megatron/training.py  +2 −0
@@ -347,6 +347,8 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     reqs = torch.distributed.batch_isend_irecv(ops)
     for req in reqs:
         req.wait()
+    # Temporary workaround for batch_isend_irecv() race condition.
+    torch.cuda.synchronize()
 
     return tensor_recv_prev, tensor_recv_next
 
megatron/training.py  +2 −0
@@ -347,6 +347,8 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     reqs = torch.distributed.batch_isend_irecv(ops)
     for req in reqs:
         req.wait()
+    # Temporary workaround for batch_isend_irecv() race condition.
+    torch.cuda.synchronize()
 
     return tensor_recv_prev, tensor_recv_next
 