Commit 08f5ff78 authored by Peder Bergebakken Sundt's avatar Peder Bergebakken Sundt Committed by SomeoneSerge
Browse files

python312Packages.pytorch3d: test cuda

parent 030fc6a5
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -46,6 +46,19 @@ buildPythonPackage rec {

  pythonImportsCheck = [ "pytorch3d" ];

  passthru.tests.rotations-cuda =
    cudaPackages.writeGpuTestPython { libraries = ps: [ ps.pytorch3d ]; }
      ''
        import pytorch3d.transforms as p3dt

        M = p3dt.random_rotations(n=10, device="cuda")
        assert "cuda" in M.device.type
        angles = p3dt.matrix_to_euler_angles(M, "XYZ")
        assert "cuda" in angles.device.type
        assert angles.shape == (10, 3), angles.shape
        print(angles)
      '';

  meta = {
    description = "FAIR's library of reusable components for deep learning with 3D data";
    homepage = "https://github.com/facebookresearch/pytorch3d";