Commit 59d8696f authored by Daniel Fahey

python3Packages.vllm: fix Blackwell support

This corrects the approach taken in b1cef6ca, which disabled Blackwell (SM100+) support due to CUTLASS API incompatibility.

My original diagnosis was incorrect: FlashMLA's SM100 kernels don't need newer CUTLASS APIs; they need CUTLASS v3.9.0 (the version pinned in FlashMLA's git submodule). The incompatibility is between vLLM's FlashMLA fork and the CUTLASS v4.x APIs as a whole.

This commit packages both CUTLASS versions: v3.9.0 for FlashMLA and v4.0.0 for the main vLLM build, allowing Blackwell support to work correctly. Although this subtlety is easily expressed in the Nix derivation, it was also reported upstream (https://github.com/vllm-project/vllm/issues/27425).
parent 27fd5766
+25 −14
@@ -104,6 +104,20 @@ let
     hash = "sha256-HJY+Go1viPkSVZPEs/NyMtYJzas4mMLiIZF3kNX+WgA=";
   };

+  # FlashMLA's Blackwell (SM100) kernels were developed against CUTLASS v3.9.0
+  # (since https://github.com/vllm-project/FlashMLA/commit/9c5dfab6d1746b4a27af14f440e7afd5c01ece68)
+  # and are currently incompatible with CUTLASS v4.x APIs. The rest of the vLLM
+  # build uses a newer CUTLASS, so we package both versions.
+  # See upstream issue: https://github.com/vllm-project/vllm/issues/27425
+  # See git submodule commit at:
+  # https://github.com/vllm-project/FlashMLA/tree/${flashmla.src.rev}/csrc
+  cutlass-flashmla = fetchFromGitHub {
+    owner = "NVIDIA";
+    repo = "cutlass";
+    tag = "v3.9.0";
+    hash = "sha256-Q6y/Z6vahASeSsfxvZDwbMFHGx8CnsF90IlveeVLO9g=";
+  };
+
   flashmla = stdenv.mkDerivation {
     pname = "flashmla";
     # https://github.com/vllm-project/FlashMLA/blob/${src.rev}/setup.py
@@ -123,7 +137,7 @@ let
     # flashmla normally relies on `git submodule update` to fetch cutlass
     buildPhase = ''
       rm -rf csrc/cutlass
-      ln -sf ${cutlass} csrc/cutlass
+      ln -sf ${cutlass-flashmla} csrc/cutlass
     '';

     installPhase = ''
@@ -199,19 +213,16 @@ let
         "8.9"
         "9.0"
         "9.0a"
-        # Blackwell (SM100+) capabilities temporarily disabled due to CUTLASS API incompatibility
-        # FlashMLA kernels require CUTLASS v4.2.1+ APIs not available in bundled v4.0.0
-        # TODO: Re-enable when vLLM upgrades CUTLASS (see https://github.com/vllm-project/vllm/pull/24673)
-        # "10.0"
-        # "10.0a"
-        # "10.1"
-        # "10.1a"
-        # "10.3"
-        # "10.3a"
-        # "12.0"
-        # "12.0a"
-        # "12.1"
-        # "12.1a"
+        "10.0"
+        "10.0a"
+        "10.1"
+        "10.1a"
+        "10.3"
+        "10.3a"
+        "12.0"
+        "12.0a"
+        "12.1"
+        "12.1a"
       ];
       ptx = lists.map (x: "${x}+PTX") real;
     in
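
Note on the last hunk: because `ptx` is derived from `real`, re-enabling the Blackwell entries also adds their `+PTX` variants. The following standalone Nix expression is only a sketch of that relationship, not part of the commit. The three-entry list is a shortened, illustrative subset, and `builtins.map` is used on the assumption that it behaves like the `lists.map` call in the derivation.

```nix
# Hypothetical standalone sketch, not part of the vllm derivation.
# Evaluate with: nix-instantiate --eval --strict ./ptx-sketch.nix
let
  # Illustrative subset of the capability list shown in the hunk above.
  real = [
    "9.0a"
    "10.0"
    "12.0"
  ];
  # Same lambda as in the derivation; builtins.map avoids importing nixpkgs.
  ptx = builtins.map (x: "${x}+PTX") real;
in
{
  inherit real ptx;
}
```

Here `ptx` evaluates to `[ "9.0a+PTX" "10.0+PTX" "12.0+PTX" ]`, so each capability uncommented in `real` yields exactly one additional PTX target.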