Unverified Commit 9d4915aa authored by Yt's avatar Yt Committed by GitHub
Browse files

python3Packages.vllm: fix Blackwell support (#455364)

parents c776f8d5 59d8696f
Loading
Loading
Loading
Loading
+25 −14
Original line number Diff line number Diff line
@@ -104,6 +104,20 @@ let
    hash = "sha256-HJY+Go1viPkSVZPEs/NyMtYJzas4mMLiIZF3kNX+WgA=";
  };

  # FlashMLA's Blackwell (SM100) kernels were developed against CUTLASS v3.9.0
  # (since https://github.com/vllm-project/FlashMLA/commit/9c5dfab6d1746b4a27af14f440e7afd5c01ece68)
  # and are currently incompatible with CUTLASS v4.x APIs. The rest of the vLLM
  # build uses a newer CUTLASS, so we package both versions.
  # See upstream issue: https://github.com/vllm-project/vllm/issues/27425
  # See git submodule commit at:
  # https://github.com/vllm-project/FlashMLA/tree/${flashmla.src.rev}/csrc
  cutlass-flashmla = fetchFromGitHub {
    owner = "NVIDIA";
    repo = "cutlass";
    tag = "v3.9.0";
    hash = "sha256-Q6y/Z6vahASeSsfxvZDwbMFHGx8CnsF90IlveeVLO9g=";
  };

  flashmla = stdenv.mkDerivation {
    pname = "flashmla";
    # https://github.com/vllm-project/FlashMLA/blob/${src.rev}/setup.py
@@ -123,7 +137,7 @@ let
    # flashmla normally relies on `git submodule update` to fetch cutlass
    buildPhase = ''
      rm -rf csrc/cutlass
      ln -sf ${cutlass} csrc/cutlass
      ln -sf ${cutlass-flashmla} csrc/cutlass
    '';

    installPhase = ''
@@ -199,19 +213,16 @@ let
        "8.9"
        "9.0"
        "9.0a"
        # Blackwell (SM100+) capabilities temporarily disabled due to CUTLASS API incompatibility
        # FlashMLA kernels require CUTLASS v4.2.1+ APIs not available in bundled v4.0.0
        # TODO: Re-enable when vLLM upgrades CUTLASS (see https://github.com/vllm-project/vllm/pull/24673)
        # "10.0"
        # "10.0a"
        # "10.1"
        # "10.1a"
        # "10.3"
        # "10.3a"
        # "12.0"
        # "12.0a"
        # "12.1"
        # "12.1a"
        "10.0"
        "10.0a"
        "10.1"
        "10.1a"
        "10.3"
        "10.3a"
        "12.0"
        "12.0a"
        "12.1"
        "12.1a"
      ];
      ptx = lists.map (x: "${x}+PTX") real;
    in