Unverified Commit 4fbd4333 authored by Someone's avatar Someone Committed by GitHub
Browse files

Merge pull request #325222 from SomeoneSerge/fix/gpu-access/torch-bin

python3Packages.torch-bin: gpuChecks -> tests.tester-<name>.gpuCheck
parents a4f437f0 3eff2015
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -121,7 +121,10 @@ buildPythonPackage {

  pythonImportsCheck = [ "torch" ];

  passthru.gpuChecks.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; };
  passthru.tests = callPackage ./tests.nix {
    torchWithCuda = torch-bin;
    torchWithRocm = torch-bin;
  };

  meta = {
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
+0 −40
Original line number Diff line number Diff line
{
  lib,
  torchWithCuda,
  torchWithRocm,
  callPackage,
}:

let
  accelAvailable =
    {
      feature,
      versionAttr,
      torch,
      cudaPackages,
    }:
    cudaPackages.writeGpuTestPython
      {
        inherit feature;
        libraries = [ torch ];
        name = "${feature}Available";
      }
      ''
        import torch
        message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
        assert torch.cuda.is_available() and torch.version.${versionAttr}, message
        print(message)
      '';
in
{
  tester-cudaAvailable = callPackage accelAvailable {
    feature = "cuda";
    versionAttr = "cuda";
    torch = torchWithCuda;
  };
  tester-rocmAvailable = callPackage accelAvailable {
    feature = "rocm";
    versionAttr = "hip";
    torch = torchWithRocm;
  };
}
+19 −0
Original line number Diff line number Diff line
{
  cudaPackages,
  feature,
  torch,
  versionAttr,
}:

cudaPackages.writeGpuTestPython
  {
    inherit feature;
    libraries = [ torch ];
    name = "${feature}Available";
  }
  ''
    import torch
    message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
    assert torch.cuda.is_available() and torch.version.${versionAttr}, message
    print(message)
  ''
+20 −2
Original line number Diff line number Diff line
{ callPackage }:
{
  callPackage,
  torchWithCuda,
  torchWithRocm,
}:

callPackage ./gpu-checks.nix { }
{
  # To perform the runtime check use either
  # `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
  # `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
  tester-cudaAvailable = callPackage ./mk-runtime-check.nix {
    feature = "cuda";
    versionAttr = "cuda";
    torch = torchWithCuda;
  };
  tester-rocmAvailable = callPackage ./mk-runtime-check.nix {
    feature = "rocm";
    versionAttr = "hip";
    torch = torchWithRocm;
  };
}