Commit 2ef23675 authored by Jared Casper's avatar Jared Casper Committed by Mohammad Shoeybi
Browse files

Support latest PyTorch RNG state API. (#8)

Fixes #7.
parent a0368ddf
Loading
Loading
Loading
Loading
+20 −3
Original line number Diff line number Diff line
@@ -41,9 +41,26 @@ def _set_cuda_rng_state(new_state, device=-1):
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)

    _lazy_call(cb)