Commit 1b3dfa2f authored by Deepak Narayanan's avatar Deepak Narayanan Committed by Jared Casper
Browse files

Use torch.cuda.synchronize() right after calling the batch_isend_irecv() communication API

parent be473a5b
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -347,6 +347,8 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
    reqs = torch.distributed.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    # Temporary workaround for batch_isend_irecv() race condition.
    torch.cuda.synchronize()

    return tensor_recv_prev, tensor_recv_next