Unverified commit c5b01a14, authored by Samuel Ainsworth, committed by GitHub
Browse files

Merge pull request #222273 from SomeoneSerge/torch20

python3Packages.torch: 1.13.1 -> 2.0.0
parents 584278ef 632cff6f
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -24,6 +24,8 @@
, targetDir ? "llvm"
, targetProjects ? [ ]
, targetRuntimes ? [ ]
# "NATIVE" resolves into x86 or aarch64 depending on stdenv
, llvmTargetsToBuild ? [ "NATIVE" ]
, extraPatches ? [ ]
, extraNativeBuildInputs ? [ ]
, extraBuildInputs ? [ ]
@@ -46,6 +48,8 @@ let
    if stdenv.isx86_64 then "X86"
    else if stdenv.isAarch64 then "AArch64"
    else throw "Unsupported ROCm LLVM platform";
  inferNativeTarget = t: if t == "NATIVE" then llvmNativeTarget else t;
  llvmTargetsToBuild' = [ "AMDGPU" ] ++ builtins.map inferNativeTarget llvmTargetsToBuild;
in stdenv.mkDerivation (finalAttrs: {
  pname = "rocm-llvm-${targetName}";
  version = "5.4.4";
@@ -98,7 +102,7 @@ in stdenv.mkDerivation (finalAttrs: {
  sourceRoot = "${finalAttrs.src.name}/${targetDir}";

  cmakeFlags = [
    "-DLLVM_TARGETS_TO_BUILD=AMDGPU;${llvmNativeTarget}"
    "-DLLVM_TARGETS_TO_BUILD=${builtins.concatStringsSep ";" llvmTargetsToBuild'}"
  ] ++ lib.optionals (finalAttrs.passthru.isLLVM && targetProjects != [ ]) [
    "-DLLVM_ENABLE_PROJECTS=${lib.concatStringsSep ";" targetProjects}"
  ] ++ lib.optionals ((finalAttrs.passthru.isLLVM || targetDir == "runtimes") && targetRuntimes != [ ]) [
+11 −2
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
, stdenv
, buildDunePackage
, fetchFromGitHub
, fetchpatch
, cmdliner
, ctypes
, dune-configurator
@@ -29,6 +30,14 @@ buildDunePackage rec {
    hash = "sha256-z/9NUBjeFWE63Z/e8OyzDiy8hrn6qzjaiBH8G9MPeos=";
  };

  patches = [
    # Pytorch 2.0 support. Drop when it reaches a release
    (fetchpatch {
      url = "https://github.com/LaurentMazare/ocaml-torch/commit/ef7ef30cafecb09e45ec1ed8ce4bedae5947cfa5.patch";
      hash = "sha256-smdwKy40iIISp/25L2J4az6KmqFS1soeChBElUyhl5A=";
    })
  ];

  buildInputs = [ dune-configurator ];

  propagatedBuildInputs = [
+253 −0
Original line number Diff line number Diff line
# Nix expression for the "triton" Python package (openai/triton 2.0.0),
# the GPU kernel compiler that torch.compile relies on in PyTorch >= 2.0.
# Builds against nixpkgs' LLVM/MLIR instead of the vendored toolchain.
{ lib
, buildPythonPackage
, python
, fetchpatch
, fetchFromGitHub
, addOpenGLRunpath
, cmake
, cudaPackages
, llvmPackages
, pybind11
, gtest
, zlib
, ncurses
, libxml2
, lit
, filelock
, torchWithRocm
# NOTE(review): pytest/pytestCheckHook are only referenced by the
# commented-out check hook below — confirm they are still needed.
, pytest
, pytestCheckHook
, pythonRelaxDepsHook
, pkgsTargetTarget
}:

let
  pname = "triton";
  version = "2.0.0";

  # cuda_cudart supplies the headers/stub libraries spliced into
  # python/triton/compiler.py below; backendStdenv pins the C/C++ compiler
  # used so the GLIBCXX ABI matches other cuda-enabled python packages.
  inherit (cudaPackages) cuda_cudart backendStdenv;

  # A time may come we'll want to be cross-friendly
  #
  # Short explanation: we need pkgsTargetTarget, because we use string
  # interpolation instead of buildInputs.
  #
  # Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
  # ptxas compiler. We're not running this ptxas on the build machine, but on
  # the user's machine, i.e. our Target platform. The second "Target" in
  # pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
  # be executed on the GPU.
  # Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
  ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas";

  # LLVM rebuilt with the NVPTX backend (plus the native target);
  # -DLLVM_INSTALL_UTILS=ON ships FileCheck & friends for the test suite.
  llvm = (llvmPackages.llvm.override {
    llvmTargetsToBuild = [ "NATIVE" "NVPTX" ];
    # Upstream CI sets these too:
    # targetProjects = [ "mlir" ];
    extraCMakeFlags = [
      "-DLLVM_INSTALL_UTILS=ON"
    ];
  });
in
buildPythonPackage {
  inherit pname version;

  format = "setuptools";

  src = fetchFromGitHub {
    owner = "openai";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
  };

  patches = [
    # Prerequisite for llvm15 patch
    (fetchpatch {
      url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch";
      hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk=";
    })
    (fetchpatch {
      url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch";
      hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg=";
    })

    # Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch
    # The original patch adds ptxas binary, so we include our own clean copy
    # Drop with the next update
    ./llvm15.patch

    # TODO: there have been commits upstream aimed at removing the "torch"
    # circular dependency, but the patches fail to apply on the release
    # revision. Keeping the link for future reference
    # Also cf. https://github.com/openai/triton/issues/1374

    # (fetchpatch {
    #   url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
    #   hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
    # })
  ];

  # Feed third-party package locations from our cmakeFlags (see the cmakeFlags
  # attribute below) instead of letting setup.py download them into its cache.
  postPatch = ''
    substituteInPlace python/setup.py \
      --replace \
        '= get_thirdparty_packages(triton_cache_path)' \
        '= os.environ["cmakeFlags"].split()'
  ''
  # Wiring triton=2.0.0 with llvmPackages_rocm.llvm=5.4.3
  # Revisit when updating either triton or llvm
  + ''
    substituteInPlace CMakeLists.txt \
      --replace "nvptx" "NVPTX" \
      --replace "LLVM 11" "LLVM"
    sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt
    sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt
    find -iname '*.td' -exec \
      sed -i \
      -e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \
      -e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \
      '{}' ';'
    substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
    sed -i 's/^include.*$//' unittest/CMakeLists.txt
    sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt
    sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt
  ''
  # TritonMLIRIR already links MLIRIR. Not transitive?
  # + ''
  #   echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt
  # ''
  # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
  + ''
    substituteInPlace bin/CMakeLists.txt \
      --replace "add_subdirectory(FileCheck)" ""

    rm cmake/FindLLVM.cmake
  ''
  +
  (
    let
      # Bash was getting weird without linting,
      # but basically upstream contains [cc, ..., "-lcuda", ...]
      # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
      old = [ "-lcuda" ];
      new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ];

      # Render the Nix lists as the comma-separated, double-quoted python
      # list fragments that appear verbatim in compiler.py.
      quote = x: ''"${x}"'';
      oldStr = lib.concatMapStringsSep ", " quote old;
      newStr = lib.concatMapStringsSep ", " quote new;
    in
    ''
      substituteInPlace python/triton/compiler.py \
        --replace '${oldStr}' '${newStr}'
    ''
  )
  # Triton seems to be looking up cuda.h
  + ''
    sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py
  '';

  nativeBuildInputs = [
    cmake
    pythonRelaxDepsHook

    # Requires torch (circular dependency) and probably needs GPUs:
    # pytestCheckHook

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
    llvmPackages.mlir
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  propagatedBuildInputs = [
    filelock
  ];

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  preConfigure = ''
    export CC="${backendStdenv.cc}/bin/cc";
    export CXX="${backendStdenv.cc}/bin/c++";

    # Upstream's setup.py tries to write cache somewhere in ~/
    export HOME=$TMPDIR

    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
    echo "
    [build_ext]
    base-dir=$PWD" >> python/setup.cfg

    # The rest (including buildPhase) is relative to ./python/
    cd python/

    # Work around download_and_copy_ptxas()
    dst_cuda="$PWD/triton/third_party/cuda/bin"
    mkdir -p "$dst_cuda"
    ln -s "${ptxas}" "$dst_cuda/"
  '';

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;
  # Consumed by setup.py through the cmakeFlags environment variable (see the
  # get_thirdparty_packages substitution in postPatch), not by the cmake hook.
  cmakeFlags = [
    "-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir"
  ];

  postFixup =
    let
      ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas";
    in
    # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
    ''
      rm -f ${ptxasDestination}
      ln -s ${ptxas} ${ptxasDestination}
    '';

  checkInputs = [
    cmake # ctest
  ];
  # The C++ unit tests are driven via ctest in preCheck instead.
  dontUseSetuptoolsCheck = true;
  preCheck =
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    ''
      (cd /build/source/python/build/temp* ; ctest)
    '' # For pytestCheckHook
    + ''
      cd test/unit
    '';
  pythonImportsCheck = [
    # Circular dependency on torch
    # "triton"
    # "triton.language"
  ];

  # Ultimately, torch is our test suite:
  passthru.tests = {
    inherit torchWithRocm;
  };

  # Dependencies declared by upstream that we intentionally do not propagate.
  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/openai/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];
  meta = with lib; {
    description = "Development repository for the Triton language and compiler";
    homepage = "https://github.com/openai/triton/";
    platforms = lib.platforms.unix;
    license = licenses.mit;
    maintainers = with maintainers; [ SomeoneSerge ];
  };
}
+4617 −0

File added.

Preview size limit exceeded, changes collapsed.

+44 −18
Original line number Diff line number Diff line
@@ -6,12 +6,18 @@

  # Native build inputs
  cmake, util-linux, linkFarm, symlinkJoin, which, pybind11, removeReferencesTo,
  pythonRelaxDepsHook,

  # Build inputs
  numactl,
  Accelerate, CoreServices, libobjc,

  # Propagated build inputs
  filelock,
  jinja2,
  networkx,
  openai-triton,
  sympy,
  numpy, pyyaml, cffi, click, typing-extensions,

  # Unit tests
@@ -49,9 +55,7 @@ let
  inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl;
in

# assert that everything needed for cuda is present and that the correct cuda versions are used
assert !cudaSupport || (let majorIs = lib.versions.major cudatoolkit.version;
                        in majorIs == "9" || majorIs == "10" || majorIs == "11");
assert cudaSupport -> (cudaPackages.cudaMajorVersion == "11");

# confirm that cudatoolkits are sync'd across dependencies
assert !(MPISupport && cudaSupport) || mpi.cudatoolkit == cudatoolkit;
@@ -129,10 +133,10 @@ let
in buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "1.13.1";
  version = "2.0.0";
  format = "setuptools";

  disabled = pythonOlder "3.7.0";
  disabled = pythonOlder "3.8.0";

  outputs = [
    "out" # output standard python package
@@ -145,7 +149,7 @@ in buildPythonPackage rec {
    repo = "pytorch";
    rev = "refs/tags/v${version}";
    fetchSubmodules = true;
    hash = "sha256-yQz+xHPw9ODRBkV9hv1th38ZmUr/fXa+K+d+cvmX3Z8=";
    hash = "sha256-cSw7+AYBUcZLz3UyK/+JWWjQxKwVBXcFvBq0XAcL3tE=";
  };

  patches = lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
@@ -155,15 +159,6 @@ in buildPythonPackage rec {
    # base is 10.12. Until we upgrade, we can fall back on the older
    # pthread support.
    ./pthreadpool-disable-gcd.diff
  ] ++ [
    # PyTorch fails to build on gcc 12 due to gloo
    # https://github.com/pytorch/pytorch/issues/77614
    (fetchpatch {
      url = "https://github.com/facebookincubator/gloo/commit/4a5e339b764261d20fc409071dc7a8b8989aa195.patch";
      stripLen = 1;
      extraPrefix = "third_party/gloo/";
      hash = "sha256-UxR1r7F6g76BWj3GBIrSy5t+YZDCWy6mMddwx+hon5w=";
    })
  ];

  postPatch = lib.optionalString rocmSupport ''
@@ -261,7 +256,16 @@ in buildPythonPackage rec {
  # Suppress gcc regression: avx512 math function raises uninitialized variable warning
  # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593
  # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939
  ++ lib.optionals stdenv.cc.isGNU [ "-Wno-error=maybe-uninitialized" "-Wno-error=uninitialized" ]));
  ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [
    "-Wno-error=maybe-uninitialized"
    "-Wno-error=uninitialized"
  ]
  # Since pytorch 2.0:
  # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’
  # ... called on pointer ‘<unknown>’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object]
  ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12" ) [
    "-Wno-error=free-nonheap-object"
  ]));

  nativeBuildInputs = [
    cmake
@@ -269,6 +273,7 @@ in buildPythonPackage rec {
    which
    ninja
    pybind11
    pythonRelaxDepsHook
    removeReferencesTo
  ] ++ lib.optionals cudaSupport [ cudatoolkit_joined ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
@@ -286,11 +291,27 @@ in buildPythonPackage rec {
    click
    numpy
    pyyaml

    # From install_requires:
    filelock
    typing-extensions
    sympy
    networkx
    jinja2

    # the following are required for tensorboard support
    pillow six future tensorboard protobuf
  ] ++ lib.optionals MPISupport [ mpi ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
  ]
  ++ lib.optionals MPISupport [ mpi ]
  ++ lib.optionals rocmSupport [ rocmtoolkit_joined ]
  # rocm build requires openai-triton;
  # openai-triton currently requires cuda_nvcc,
  # so not including it in the cpu-only build;
  # torch.compile relies on openai-triton,
  # so we include it for the cuda build as well
  ++ lib.optionals (rocmSupport || cudaSupport) [
    openai-triton
  ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;
@@ -318,6 +339,11 @@ in buildPythonPackage rec {
    "runHook postCheck"
  ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];

  postInstall = ''
    find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

Loading