diff --git a/README.md b/README.md index aee852a..4489d6c 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,21 @@ conda activate joyai pip install -e . ``` -> **Note on Flash Attention**: `flash-attn >= 2.8.0` is listed as a dependency for best performance. +> **Note on Flash Attention**: Flash Attention 3 is loaded through the [`kernels`](https://github.com/huggingface/kernels) library (installed by the command above), which fetches the pre-built [`kernels-community/flash-attn3`](https://huggingface.co/kernels-community/flash-attn3) binaries on first use — no local compilation required. Once resolved, the kernel is used exactly like the upstream `flash_attn_interface`, so behavior is identical. +> +> If your machine is not covered by the pre-built binaries, the code automatically falls back to a local Flash Attention 3 build. Build it from source (the `hopper` directory of [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)): +> +> ```bash +> git clone https://github.com/Dao-AILab/flash-attention.git +> cd flash-attention/hopper +> python setup.py install +> ``` +> +> If your FA3 build lives outside the import path, point to it with the `FLASH_ATTN3_PATH` environment variable: +> +> ```bash +> export FLASH_ATTN3_PATH=/path/to/flash-attention/hopper +> ``` #### Core Dependencies @@ -100,7 +114,8 @@ pip install -e . | `torch` | >= 2.8 | PyTorch | | `transformers` | >= 4.57.0, < 4.58.0 | Text encoder | | `diffusers` | >= 0.34.0 | Pipeline utilities | -| `flash-attn` | >= 2.8.0 | Fast attention kernel | +| `kernels` | latest | Loads pre-built Flash Attention 3 | +| `flash-attn` | 3 (Hopper build) | Fallback fast attention kernel | ### 2. 
Inference diff --git a/pyproject.toml b/pyproject.toml index 92f56eb..eb45d00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "accelerate", "diffusers==0.36.0", "einops", + "kernels", "loguru", "packaging", "pillow", @@ -21,7 +22,6 @@ dependencies = [ "torch==2.8.0", "torchvision", "transformers>=4.57.0,<4.58.0", - "flash-attn>=2.8.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 59ff70e..7133364 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ accelerate diffusers>=0.34.0 einops fastapi -flash-attn>=2.8.0 +kernels loguru openai packaging diff --git a/src/modules/models/attention.py b/src/modules/models/attention.py index 70cdd4b..cb461d3 100644 --- a/src/modules/models/attention.py +++ b/src/modules/models/attention.py @@ -1,22 +1,20 @@ # Adapted from https://github.com/hao-ai-lab/FastVideo/tree/main/fastvideo/attention -import os -import sys import torch from einops import rearrange -_FLASH_ATTN_IMPORT_ERROR = None +from modules.models.flash_attn_utils import load_flash_attn +# Flash Attention is loaded through the `kernels` library (pre-built FA3 +# binaries, no local compilation), with a transparent fallback to a source +# build / local installation of `flash-attn`. See `flash_attn_utils.py`. 
+_FLASH_ATTN_IMPORT_ERROR = None try: - # Check for Flash Attention 3 installation path - flash_attn3_path = os.getenv("FLASH_ATTN3_PATH") - if flash_attn3_path: - print(f"Using Flash Attention 3 from: {flash_attn3_path}") - sys.path.insert(0, flash_attn3_path) - from flash_attn_interface import flash_attn_varlen_func - else: - from flash_attn.flash_attn_interface import flash_attn_varlen_func -except ImportError as exc: + _flash_attn_symbols = load_flash_attn() + flash_attn_varlen_func = _flash_attn_symbols.get("flash_attn_varlen_func") + if flash_attn_varlen_func is None: + raise ImportError("flash_attn is not available via `kernels` or a source build.") +except Exception as exc: # noqa: BLE001 flash_attn_varlen_func = None _FLASH_ATTN_IMPORT_ERROR = exc diff --git a/src/modules/models/flash_attn_utils.py b/src/modules/models/flash_attn_utils.py new file mode 100644 index 0000000..0507e23 --- /dev/null +++ b/src/modules/models/flash_attn_utils.py @@ -0,0 +1,88 @@ +"""Loader for the Flash Attention symbols used by the attention backend. + +Building Flash Attention from source is slow. To avoid that, we first try to +load pre-built Flash Attention 3 binaries served through the Hugging Face +``kernels`` library. When pre-built binaries are not available for the current +platform (unsupported GPU/arch, no network, ...), we transparently fall back to +a source build / local installation of Flash Attention (the ``FLASH_ATTN3_PATH`` +override or the installed ``flash-attn`` package), so the dependency stays +optional. + +Inspired by https://github.com/OpenDriveLab/AgiBot-World/pull/159 +""" + +import logging +import os +import sys +from functools import lru_cache + +logger = logging.getLogger(__name__) + +# Kernel served by the `kernels` library that mirrors the Flash Attention 3 API. 
_FLASH_ATTN_KERNEL = "kernels-community/flash-attn3"
# Forwarded verbatim as the `version=` argument of `kernels.get_kernel`.
_FLASH_ATTN_KERNEL_VERSION = 1


@lru_cache(maxsize=None)
def load_flash_attn():
    """Return a dict of Flash Attention symbols, or an empty dict if unavailable.

    The returned dict exposes the callable the attention backend relies on
    under the key ``"flash_attn_varlen_func"``.

    Resolution order:
      1. Pre-built FA3 binaries via the ``kernels`` library (no compilation).
      2. A source build / local installation of Flash Attention
         (see ``_load_from_source`` for the exact probing order).

    The result is memoised with ``lru_cache``, so the loaders run at most once
    per process and every call returns the *same* dict object; callers should
    treat it as read-only.
    """
    # The loaders are looked up at call time, so it is fine that they are
    # defined further down in the module.
    for loader in (_load_from_kernels, _load_from_source):
        symbols = loader()
        if symbols is not None:
            return symbols
    return {}


def _load_from_kernels():
    """Try to load the pre-built kernel through the ``kernels`` library.

    Returns:
        A dict with the ``"flash_attn_varlen_func"`` callable, or ``None`` when
        anything goes wrong (``kernels`` missing, download failure, the loaded
        module not exposing the expected attribute, ...), so that the caller
        can fall back to a source build.
    """
    try:
        from kernels import get_kernel

        module = get_kernel(_FLASH_ATTN_KERNEL, version=_FLASH_ATTN_KERNEL_VERSION)
        # The flash-attn3 kernel exposes its callables at the top level of the
        # module (unlike flash-attn2, which nests them under
        # `flash_attention_interface`).
        symbols = {
            "flash_attn_varlen_func": module.flash_attn_varlen_func,
        }
    except Exception as exc:  # noqa: BLE001 - any failure should trigger the source fallback
        logger.info(
            "Pre-built Flash Attention via `kernels` unavailable (%s); "
            "falling back to a source build.",
            exc,
        )
        return None

    # Report success only once the symbol has actually been resolved, so a
    # failed attribute lookup cannot produce a misleading "Loaded" message.
    logger.info(
        "Loaded pre-built Flash Attention 3 binaries via `kernels` (%s).",
        _FLASH_ATTN_KERNEL,
    )
    return symbols


def _load_from_source():
    """Try to load Flash Attention from a local build or installed package.

    Probing order:
      1. The top-level ``flash_attn_interface`` module. If the
         ``FLASH_ATTN3_PATH`` environment variable is set, that directory is
         first pushed to the front of ``sys.path`` so it takes precedence.
         The module is also probed when the variable is unset, so that a
         Flash Attention 3 build installed into the environment is found
         (assumes the Hopper build installs a top-level
         ``flash_attn_interface`` module - TODO confirm against the
         flash-attention version in use).
      2. ``flash_attn.flash_attn_interface`` from the installed ``flash-attn``
         package.

    Returns:
        A dict with the ``"flash_attn_varlen_func"`` callable, or ``None`` if
        neither module can be imported.
    """
    flash_attn3_path = os.getenv("FLASH_ATTN3_PATH")
    if flash_attn3_path:
        logger.info("Using Flash Attention 3 from: %s", flash_attn3_path)
        sys.path.insert(0, flash_attn3_path)

    try:
        from flash_attn_interface import flash_attn_varlen_func
    except ImportError as exc:
        if flash_attn3_path:
            # A bad override should not silently disable the remaining
            # fallback; say why it was skipped and keep going.
            logger.warning(
                "FLASH_ATTN3_PATH=%s does not provide `flash_attn_interface` (%s); "
                "trying the installed `flash-attn` package.",
                flash_attn3_path,
                exc,
            )
        try:
            from flash_attn.flash_attn_interface import flash_attn_varlen_func
        except ImportError:
            return None

    return {
        "flash_attn_varlen_func": flash_attn_varlen_func,
    }