Unverified commit c0874d92, authored by Connor Baker and committed by GitHub
Browse files

python3Packages.warp-lang: fix CUDA build (#419750)

parents a43007ac e5d0d6ca
Loading
Loading
Loading
Loading
+228 −84
Original line number Diff line number Diff line
{
  config,
  lib,
  stdenv,
  autoAddDriverRunpath,
  buildPythonPackage,
  fetchurl,
  config,
  cudaPackages,
  fetchFromGitHub,
  fetchurl,
  jax,
  lib,
  llvmPackages,
  numpy,
  pkgsBuildHost,
  python,
  replaceVars,
  build,
  runCommand,
  setuptools,
  numpy,
  llvmPackages,
  cudaPackages,
  unittestCheckHook,
  jax,
  stdenv,
  torch,
  nix-update-script,
  warp-lang, # Self-reference to this package for passthru.tests
  writableTmpDirAsHomeHook,
  writeShellApplication,

  # Use standalone LLVM-based JIT compiler and CPU device support
  standaloneSupport ? true,
@@ -25,63 +29,69 @@
  # Build Warp with MathDx support (requires CUDA support)
  # Most linear-algebra tile operations like tile_cholesky(), tile_fft(),
  # and tile_matmul() require Warp to be built with the MathDx library.
  libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux,
}:

  # libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux,
  libmathdxSupport ? cudaSupport,
}@args:
assert libmathdxSupport -> cudaSupport;
let
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else args.stdenv;
  stdenv = builtins.throw "Use effectiveStdenv instead of stdenv directly, as it may be replaced by cudaPackages.backendStdenv";

  version = "1.7.2.post1";

  libmathdx = stdenv.mkDerivation (finalAttrs: {
  libmathdx = effectiveStdenv.mkDerivation (finalAttrs: {
    # NOTE: The version used should match the version Warp requires:
    # https://github.com/NVIDIA/warp/blob/4ad209076ce09668b18dedc74dce0d5cf8b9e409/deps/libmathdx-deps.packman.xml
    pname = "libmathdx";
    version = "0.2.0";
    version = "0.1.2";

    outputs = [
      "out"
      "static"
    ];

    src =
      let
        inherit (stdenv.hostPlatform) system;
        selectSystem = attrs: attrs.${system} or (throw "Unsupported system: ${system}");

        suffix = selectSystem {
          x86_64-linux = "Linux-x86_64";
          aarch64-linux = "Linux-aarch64";
          x86_64-windows = "win32-x86_64";
        };

        # nix-hash --type sha256 --to-sri $(nix-prefetch-url "https://...")
        hash = selectSystem {
          x86_64-linux = "sha256-Lk+PxWFvyQGRClFdmyuo4y7HBdR7pigOhMyEzajqbmg=";
          aarch64-linux = "sha256-6tH9YH98kSvDiut9rQEU5potEpeKqma/QtrCHLxwRLo=";
          x86_64-windows = "sha256-B8qwj7UzOXEDZh2oT3ip1qW0uqtygMsyfcbhh5Dgc8U=";
        baseURL = "https://developer.download.nvidia.com/compute/cublasdx/redist/cublasdx";
        name = lib.concatStringsSep "-" [
          finalAttrs.pname
          "Linux"
          effectiveStdenv.hostPlatform.parsed.cpu.name
          finalAttrs.version
        ];
        hashes = {
          aarch64-linux = "sha256-7HEXfzxPF62q/7pdZidj4eO09u588yxcpSu/bWot/9A=";
          x86_64-linux = "sha256-MImBFv+ooRSUqdL/YEe/bJIcVBnHMCk7SLS5eSeh0cQ=";
        };
      in
      lib.mapNullable (
        hash:
        fetchurl {
        url = "https://developer.nvidia.com/downloads/compute/cublasdx/redist/cublasdx/libmathdx-${suffix}-${finalAttrs.version}.tar.gz";
        inherit hash;
      };

    unpackPhase = ''
      runHook preUnpack

      mkdir unpacked
      cd unpacked
      tar -xzf $src
      export sourceRoot=$(pwd)

      runHook postUnpack
    '';
          inherit hash name;
          url = "${baseURL}/${name}.tar.gz";
        }
      ) (hashes.${effectiveStdenv.hostPlatform.system} or null);

    dontUnpack = true;
    dontConfigure = true;
    dontBuild = true;

    # NOTE: The leading component is stripped because the 0.1.2 release is within the `libmathdx` directory.
    installPhase = ''
      runHook preInstall

      cp -rT "$sourceRoot" "$out"
      mkdir -p "$out"
      tar -xzf "$src" --strip-components=1 -C "$out"

      mkdir -p "$static"
      moveToOutput "lib/libmathdx_static.a" "$static"

      runHook postInstall
    '';

    meta = {
      description = "library used to integrate cuBLASDx and cuFFTDx into Warp";
      homepage = "https://developer.nvidia.com/cublasdx-downloads";
      sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
      license = with lib.licenses; [
        # By downloading and using the software, you agree to fully
@@ -104,7 +114,10 @@ let
        # license:
        mit
      ];
      platforms = with lib.platforms; linux ++ [ "x86_64-windows" ];
      platforms = [
        "aarch64-linux"
        "x86_64-linux"
      ];
      maintainers = with lib.maintainers; [ yzx9 ];
    };
  });
@@ -114,6 +127,13 @@ buildPythonPackage {
  inherit version;
  pyproject = true;

  # TODO(@connorbaker): Some CUDA setup hook is failing when __structuredAttrs is false,
  # causing a bunch of missing math symbols (like expf) when linking against the static library
  # provided by NVCC.
  __structuredAttrs = true;

  stdenv = effectiveStdenv;

  src = fetchFromGitHub {
    owner = "NVIDIA";
    repo = "warp";
@@ -122,7 +142,7 @@ buildPythonPackage {
  };

  patches =
    lib.optionals stdenv.hostPlatform.isDarwin [
    lib.optionals effectiveStdenv.hostPlatform.isDarwin [
      (replaceVars ./darwin-libcxx.patch {
        LIBCXX_DEV = llvmPackages.libcxx.dev;
        LIBCXX_LIB = llvmPackages.libcxx;
@@ -140,22 +160,69 @@ buildPythonPackage {
    ];

  postPatch =
    lib.optionalString (!stdenv.cc.isGNU) ''
      substituteInPlace warp/build_dll.py \
        --replace-fail "g++" "${lib.getExe stdenv.cc}"
    # Patch build_dll.py to use our gencode flags rather than NVIDIA's very broad defaults.
    # NOTE: After 1.7.2, patching will need to be updated like this:
    # https://github.com/ConnorBaker/cuda-packages/blob/2fc8ba8c37acee427a94cdd1def55c2ec701ad82/pkgs/development/python-modules/warp/default.nix#L56-L65
    lib.optionalString cudaSupport ''
      nixLog "patching $PWD/warp/build_dll.py to use our gencode flags"
      substituteInPlace "$PWD/warp/build_dll.py" \
        --replace-fail \
          'nvcc_opts = gencode_opts + [' \
          'nvcc_opts = [ ${
            lib.concatMapStringsSep ", " (gencodeString: ''"${gencodeString}"'') cudaPackages.flags.gencode
          }, '
    ''
    # Patch build_dll.py to use dynamic libraries rather than static ones.
    # NOTE: We do not patch the `nvptxcompiler_static` path because it is not available as a dynamic library.
    + lib.optionalString cudaSupport ''
      nixLog "patching $PWD/warp/build_dll.py to use dynamic libraries"
      substituteInPlace "$PWD/warp/build_dll.py" \
        --replace-fail \
          '-lcudart_static' \
          '-lcudart' \
        --replace-fail \
          '-lnvrtc_static' \
          '-lnvrtc' \
        --replace-fail \
          '-lnvrtc-builtins_static' \
          '-lnvrtc-builtins' \
        --replace-fail \
          '-lnvJitLink_static' \
          '-lnvJitLink' \
        --replace-fail \
          '-lmathdx_static' \
          '-lmathdx'
    ''
    + ''
      nixLog "patching $PWD/warp/build_dll.py to use our C++ compiler"
      substituteInPlace "$PWD/warp/build_dll.py" \
        --replace-fail "g++" "c++"
    ''
    # Broken tests on aarch64. Since unittest doesn't support disabling a
    # single test, and pytest isn't compatible, we patch the test file directly
    # instead.
    #
    # See: https://github.com/NVIDIA/warp/issues/552
    + lib.optionalString stdenv.hostPlatform.isAarch64 ''
      substituteInPlace warp/tests/test_fem.py \
        --replace-fail "add_function_test(TestFem, \"test_integrate_gradient\", test_integrate_gradient, devices=devices)" ""
    + lib.optionalString effectiveStdenv.hostPlatform.isAarch64 ''
      nixLog "patching $PWD/warp/tests/test_fem.py to disable broken tests on aarch64"
      substituteInPlace "$PWD/warp/tests/test_fem.py" \
        --replace-fail \
          'add_function_test(TestFem, "test_integrate_gradient", test_integrate_gradient, devices=devices)' \
          ""
    ''
    # These tests fail on CPU and CUDA.
    + ''
      nixLog "patching $PWD/warp/tests/test_reload.py to disable broken tests"
      substituteInPlace "$PWD/warp/tests/test_reload.py" \
        --replace-fail \
          'add_function_test(TestReload, "test_reload", test_reload, devices=devices)' \
          "" \
        --replace-fail \
          'add_function_test(TestReload, "test_reload_references", test_reload_references, devices=get_test_devices("basic"))' \
          ""
    '';

  build-system = [
    build
    setuptools
  ];

@@ -163,11 +230,11 @@ buildPythonPackage {
    numpy
  ];

  nativeBuildInputs = lib.optionals libmathdxSupport [
    libmathdx
    cudaPackages.libcublas
    cudaPackages.libcufft
    cudaPackages.libnvjitlink
  # NOTE: While normally we wouldn't include autoAddDriverRunpath for packages built from source, since Warp
  # will be loading GPU drivers at runtime, we need to inject the path to our video drivers.
  nativeBuildInputs = lib.optionals cudaSupport [
    autoAddDriverRunpath
    cudaPackages.cuda_nvcc
  ];

  buildInputs =
@@ -177,10 +244,18 @@ buildPythonPackage {
      llvmPackages.libcxx
    ]
    ++ lib.optionals cudaSupport [
      cudaPackages.cudatoolkit
      (lib.getOutput "static" cudaPackages.cuda_nvcc) # dependency on nvptxcompiler_static; no dynamic version available
      cudaPackages.cuda_cccl
      cudaPackages.cuda_cudart
      cudaPackages.cuda_nvcc
      cudaPackages.cuda_nvrtc
    ]
    ++ lib.optionals libmathdxSupport [
      libmathdx
      cudaPackages.libcublas
      cudaPackages.libcufft
      cudaPackages.libcusolver
      cudaPackages.libnvjitlink
    ];

  preBuild =
@@ -190,7 +265,8 @@ buildPythonPackage {
          "--no_standalone"
        ]
        ++ lib.optionals cudaSupport [
          "--cuda_path=${cudaPackages.cudatoolkit}"
          # NOTE: The `cuda_path` argument is the directory which contains `bin/nvcc` (i.e., the bin output).
          "--cuda_path=${lib.getBin pkgsBuildHost.cudaPackages.cuda_nvcc}"
        ]
        ++ lib.optionals libmathdxSupport [
          "--libmathdx"
@@ -203,34 +279,102 @@ buildPythonPackage {
      buildOptionString = lib.concatStringsSep " " buildOptions;
    in
    ''
      python build_lib.py ${buildOptionString}
      nixLog "running $PWD/build_lib.py to create components necessary to build the wheel"
      "${python.pythonOnBuildForHost.interpreter}" "$PWD/build_lib.py" ${buildOptionString}
    '';

  pythonImportsCheck = [
    "warp"
  ];

  # See passthru.tests.
  doCheck = false;

  passthru = {
    # Make libmathdx available for introspection.
    inherit libmathdx;

    # Scripts which provide test packages and implement test logic.
    #
    # `unit-tests` is a shell application named `warp-lang-unit-tests` which runs
    # Warp's bundled test suite via `python3 -m warp.tests`. The sandboxed checks
    # under `passthru.tests` add it to their nativeBuildInputs and invoke it by name.
    testers.unit-tests = writeShellApplication {
      name = "warp-lang-unit-tests";
      runtimeInputs = [
        # Python environment containing this package together with jax and torch.
        # The withPackages callback ignores its package-set argument, so all three
        # are the derivations passed to this expression as arguments; overriding
        # `warp-lang` therefore changes which build is tested.
        # NOTE(review): jax and torch are presumably needed for the interop tests -- confirm.
        (python.withPackages (_: [
          warp-lang
          jax
          torch
        ]))
        # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected
        #  (paddlepaddle.override { inherit cudaSupport; })
      ];
      text = ''
        python3 -m warp.tests
      '';
    };

    # Tests run within the Nix sandbox.
    tests =
      let
        mkUnitTests =
          {
            cudaSupport,
            libmathdxSupport,
          }:
          let
            name =
              "warp-lang-unit-tests-cpu" # CPU is baseline
              + lib.optionalString cudaSupport "-cuda"
              + lib.optionalString libmathdxSupport "-libmathdx";

            warp-lang' = warp-lang.override {
              inherit cudaSupport libmathdxSupport;
              # Make sure the warp-lang provided through callPackage is replaced with the override we're making.
              warp-lang = warp-lang';
            };
          in
          runCommand name
            {
              nativeBuildInputs = [
                warp-lang'.passthru.testers.unit-tests
                writableTmpDirAsHomeHook
              ];
              requiredSystemFeatures = lib.optionals cudaSupport [ "cuda" ];
              # Many unit tests fail with segfaults on aarch64-linux, especially in the sim
              # and grad modules. However, other functionality generally works, so we don't
              # mark the package as broken.
              #
              # See: https://www.github.com/NVIDIA/warp/issues/{356,372,552}
  doCheck = !(stdenv.hostPlatform.isAarch64 && stdenv.hostPlatform.isLinux);

  nativeCheckInputs = [
    unittestCheckHook
    (jax.override { inherit cudaSupport; })
    (torch.override { inherit cudaSupport; })

    # # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected
    #  (paddlepaddle.override { inherit cudaSupport; })
  ];

  preCheck = ''
    export WARP_CACHE_PATH=$(mktemp -d) # warp.config.kernel_cache_dir
              meta.broken = effectiveStdenv.hostPlatform.isAarch64 && effectiveStdenv.hostPlatform.isLinux;
            }
            ''
              nixLog "running ${name}"

              if warp-lang-unit-tests; then
                nixLog "${name} passed"
                touch "$out"
              else
                nixErrorLog "${name} failed"
                exit 1
              fi
            '';

  passthru.updateScript = nix-update-script { };
      in
      {
        cpu = mkUnitTests {
          cudaSupport = false;
          libmathdxSupport = false;
        };
        cuda = {
          cudaOnly = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = false;
          };
          cudaWithLibmathDx = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = true;
          };
        };
      };
  };

  meta = {
    description = "Python framework for high performance GPU simulation and graphics";