Commit 2aa951fa authored by SomeoneSerge's avatar SomeoneSerge
Browse files

python3Packages.triton.tests.axpy-cuda: init

parent ae560061
Loading
Loading
Loading
Loading
+29 −0
Original line number Diff line number Diff line
From c5d4087519eae6f41c80bbd8ffbcc9390db44c7f Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Thu, 10 Oct 2024 19:19:18 +0000
Subject: [PATCH] cmake.py: propagate cmakeFlags from environment

---
 tools/setup_helpers/cmake.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py
index 4b605fe5975..ea1d6a1ef46 100644
--- a/tools/setup_helpers/cmake.py
+++ b/tools/setup_helpers/cmake.py
@@ -332,6 +332,12 @@ class CMake:
                         file=sys.stderr,
                     )
                     print(e, file=sys.stderr)
+
+        # Nixpkgs compat:
+        if "cmakeFlags" in os.environ:
+            import shlex
+            args.extend(shlex.split(os.environ["cmakeFlags"]))
+
         # According to the CMake manual, we should pass the arguments first,
         # and put the directory as the last element. Otherwise, these flags
         # may not be passed correctly.
-- 
2.46.0
+46 −10
Original line number Diff line number Diff line
@@ -35,10 +35,8 @@
  removeReferencesTo,

  # Build inputs
  darwin,
  numactl,
  Accelerate,
  CoreServices,
  libobjc,

  # Propagated build inputs
  astunparse,
@@ -56,6 +54,17 @@
  tritonSupport ? (!stdenv.hostPlatform.isDarwin),
  triton,

  # TODO: 1. callPackage needs to learn to distinguish between the task
  #          of "asking for an attribute from the parent scope" and
  #          the task of "exposing a formal parameter in .override".
  # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.)
  #          as they routinely end up consuming the wrong arguments
  #          (dependencies without cuda support).
  #          Instead we should rely on overlays and nixpkgsFun.
  # (@SomeoneSerge)
  _tritonEffective ? if cudaSupport then triton-cuda else triton,
  triton-cuda,

  # Unit tests
  hypothesis,
  psutil,
@@ -95,6 +104,8 @@ let
    ;
  inherit (cudaPackages) cudaFlags cudnn nccl;

  triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";

  rocmPackages = rocmPackages_5;

  setBool = v: if v then "1" else "0";
@@ -240,6 +251,7 @@ buildPythonPackage rec {
      # Allow setting PYTHON_LIB_REL_PATH with an environment variable.
      # https://github.com/pytorch/pytorch/pull/128419
      ./passthrough-python-lib-rel-path.patch
      ./0001-cmake.py-propagate-cmakeFlags-from-environment.patch
    ]
    ++ lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
@@ -257,7 +269,18 @@ buildPythonPackage rec {
    ];

  postPatch =
    lib.optionalString rocmSupport ''
    ''
      substituteInPlace cmake/public/cuda.cmake \
        --replace-fail \
          'message(FATAL_ERROR "Found two conflicting CUDA' \
          'message(WARNING "Found two conflicting CUDA' \
        --replace-warn \
          "set(CUDAToolkit_ROOT" \
          "# Upstream: set(CUDAToolkit_ROOT"
      substituteInPlace third_party/gloo/cmake/Cuda.cmake \
        --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"
    ''
    + lib.optionalString rocmSupport ''
      # https://github.com/facebookincubator/gloo/pull/297
      substituteInPlace third_party/gloo/cmake/Hipify.cmake \
        --replace "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"
@@ -351,6 +374,17 @@ buildPythonPackage rec {
  # NB technical debt: building without NNPACK as workaround for missing `six`
  USE_NNPACK = 0;

  cmakeFlags =
    [
      # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
      (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaVersion)
    ]
    ++ lib.optionals cudaSupport [
      # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
      # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
      (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaVersion)
    ];

  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
@@ -495,11 +529,11 @@ buildPythonPackage rec {
    ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
    ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      Accelerate
      CoreServices
      libobjc
      darwin.apple_sdk.frameworks.Accelerate
      darwin.apple_sdk.frameworks.CoreServices
      darwin.libobjc
    ]
    ++ lib.optionals tritonSupport [ triton ]
    ++ lib.optionals tritonSupport [ _tritonEffective ]
    ++ lib.optionals MPISupport [ mpi ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

@@ -527,7 +561,7 @@ buildPythonPackage rec {

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ] ++ lib.optionals tritonSupport [ triton ];
  ] ++ lib.optionals tritonSupport [ _tritonEffective ];

  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
@@ -662,7 +696,9 @@ buildPythonPackage rec {
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms = with lib.platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
    platforms =
      lib.platforms.linux
      ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}
+85 −18
Original line number Diff line number Diff line
@@ -15,7 +15,8 @@
  ninja,
  pybind11,
  python,
  runCommand,
  pytestCheckHook,
  stdenv,
  substituteAll,
  setuptools,
  torchWithRocm,
@@ -23,6 +24,7 @@
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  triton,
}:

buildPythonPackage {
@@ -164,16 +166,10 @@ buildPythonPackage {
  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  nativeCheckInputs = [
    cmake
    # Requires torch (circular dependency) and GPU access: pytestCheckHook
  ];
  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)

    # For pytestCheckHook
    cd test/unit
  '';

  pythonImportsCheck = [
@@ -181,20 +177,91 @@ buildPythonPackage {
    "triton.language"
  ];

  # Ultimately, torch is our test suite:
  passthru.gpuCheck = stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [
      (python.withPackages (ps: [
        ps.scipy
        ps.torchWithCuda
        ps.triton-cuda
      ]))
    ];

    dontBuild = true;
    nativeCheckInputs = [ pytestCheckHook ];

    doCheck = true;

    preCheck = ''
      cd python/test/unit
      export HOME=$TMPDIR
    '';
    checkPhase = "pytestCheckPhase";

    installPhase = "touch $out";
  };

  passthru.tests = {
    # Ultimately, torch is our test suite:
    inherit torchWithRocm;
    # Implemented as an alternative to pythonImportsCheck, in case the circular dependency on torch occurs again
    # and pythonImportsCheck has to be commented out again.
    import-triton =
      runCommand "import-triton"
        { nativeBuildInputs = [ (python.withPackages (ps: [ ps.triton ])) ]; }

    # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
    # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
    axpy-cuda =
      cudaPackages.writeGpuTestPython
        {
          libraries = ps: [
            ps.triton
            ps.torch-no-triton
          ];
        }
        ''
          python << \EOF
          # Adapted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

          import triton
          import triton.language
          EOF
          touch "$out"
          import triton.language as tl
          import torch
          import os

          # Triton kernel computing out = a * x + y elementwise. Each program
          # instance handles one BLOCK_SIZE-wide slice of the n elements.
          #   n          -- number of elements to process
          #   a          -- scalar multiplier (compile-time constant)
          #   x_ptr      -- base pointer of the first input
          #   y_ptr      -- base pointer of the second input
          #   out        -- base pointer of the destination
          #   BLOCK_SIZE -- elements per program instance (compile-time constant)
          @triton.jit
          def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
            # Index of this program instance along axis 0 of the launch grid
            pid = tl.program_id(axis=0)
            # Element offsets covered by this instance
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            # Only offsets below n are loaded and stored; this guards the last
            # block when n is not a multiple of BLOCK_SIZE
            mask = offsets < n
            x = tl.load(x_ptr + offsets, mask=mask)
            y = tl.load(y_ptr + offsets, mask=mask)
            output = a * x + y
            tl.store(out + offsets, output, mask=mask)

          # Host-side wrapper around axpy_kernel: allocates a result tensor shaped
          # like x, launches the kernel over all of its elements and returns it.
          # NOTE(review): presumably y must have at least as many elements as x;
          # this is not checked here -- confirm against callers.
          def axpy(a, x, y):
            output = torch.empty_like(x)
            # All tensors handed to the kernel are required to be CUDA tensors
            assert x.is_cuda and y.is_cuda and output.is_cuda
            n_elements = output.numel()

            # Launch grid: one program instance per BLOCK_SIZE elements, rounded up
            def grid(meta):
              return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

            axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
            return output

          if __name__ == "__main__":
            if os.environ.get("HOME", None) == "/homeless-shelter":
              os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
            if "CC" not in os.environ:
              os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
            torch.manual_seed(0)
            size = 12345
            x = torch.rand(size, device='cuda')
            y = torch.rand(size, device='cuda')
            output_torch = 3.14 * x + y
            output_triton = axpy(3.14, x, y)
            assert output_torch.sub(output_triton).abs().max().item() < 1e-6
            print("Triton axpy: OK")
        '';
  };

+4 −4
Original line number Diff line number Diff line
@@ -15717,10 +15717,10 @@ self: super: with self; {
  toposort = callPackage ../development/python-modules/toposort { };
  torch = callPackage ../development/python-modules/torch {
    inherit (pkgs.darwin.apple_sdk.frameworks) Accelerate CoreServices;
    inherit (pkgs.darwin) libobjc;
  };
  torch = callPackage ../development/python-modules/torch { };
  # Required to test triton
  torch-no-triton = self.torch.override { tritonSupport = false; };
  torch-audiomentations = callPackage ../development/python-modules/torch-audiomentations { };