Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,21 @@ conda activate joyai
pip install -e .
```

> **Note on Flash Attention**: `flash-attn >= 2.8.0` is listed as a dependency for best performance.
> **Note on Flash Attention**: Flash Attention 3 is loaded through the [`kernels`](https://github.com/huggingface/kernels) library (installed by the command above), which fetches the pre-built [`kernels-community/flash-attn3`](https://huggingface.co/kernels-community/flash-attn3) binaries on first use — no local compilation required. Once resolved, the kernel is used exactly like the upstream `flash_attn_interface`, so behavior is identical.
>
> If your machine is not covered by the pre-built binaries, the code automatically falls back to a local Flash Attention 3 build. Build it from source (the `hopper` directory of [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)):
>
> ```bash
> git clone https://github.com/Dao-AILab/flash-attention.git
> cd flash-attention/hopper
> python setup.py install
> ```
>
> If your FA3 build lives outside the import path, point to it with the `FLASH_ATTN3_PATH` environment variable:
>
> ```bash
> export FLASH_ATTN3_PATH=/path/to/flash-attention/hopper
> ```

#### Core Dependencies

Expand All @@ -100,7 +114,8 @@ pip install -e .
| `torch` | >= 2.8 | PyTorch |
| `transformers` | >= 4.57.0, < 4.58.0 | Text encoder |
| `diffusers` | >= 0.34.0 | Pipeline utilities |
| `flash-attn` | >= 2.8.0 | Fast attention kernel |
| `kernels` | latest | Loads pre-built Flash Attention 3 |
| `flash-attn` | 3 (Hopper build) | Fallback fast attention kernel |


### 2. Inference
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"accelerate",
"diffusers==0.36.0",
"einops",
"kernels",
"loguru",
"packaging",
"pillow",
Expand All @@ -21,7 +22,6 @@ dependencies = [
"torch==2.8.0",
"torchvision",
"transformers>=4.57.0,<4.58.0",
"flash-attn>=2.8.0",
]

[project.optional-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ accelerate
diffusers>=0.34.0
einops
fastapi
flash-attn>=2.8.0
kernels
loguru
openai
packaging
Expand Down
22 changes: 10 additions & 12 deletions src/modules/models/attention.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
# Adapted from https://github.com/hao-ai-lab/FastVideo/tree/main/fastvideo/attention

import os
import sys
import torch
from einops import rearrange

_FLASH_ATTN_IMPORT_ERROR = None
from modules.models.flash_attn_utils import load_flash_attn

# Flash Attention is loaded through the `kernels` library (pre-built FA2
# binaries, no local compilation), with a transparent fallback to a source
# build / local installation of `flash-attn`. See `flash_attn_utils.py`.
_FLASH_ATTN_IMPORT_ERROR = None
try:
# Check for Flash Attention 3 installation path
flash_attn3_path = os.getenv("FLASH_ATTN3_PATH")
if flash_attn3_path:
print(f"Using Flash Attention 3 from: {flash_attn3_path}")
sys.path.insert(0, flash_attn3_path)
from flash_attn_interface import flash_attn_varlen_func
else:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError as exc:
_flash_attn_symbols = load_flash_attn()
flash_attn_varlen_func = _flash_attn_symbols.get("flash_attn_varlen_func")
if flash_attn_varlen_func is None:
raise ImportError("flash_attn is not available via `kernels` or a source build.")
except Exception as exc: # noqa: BLE001
flash_attn_varlen_func = None
_FLASH_ATTN_IMPORT_ERROR = exc

Expand Down
88 changes: 88 additions & 0 deletions src/modules/models/flash_attn_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Loader for the Flash Attention symbols used by the attention backend.

Building Flash Attention from source is slow. To avoid that, we first try to
load pre-built Flash Attention 3 binaries served through the Hugging Face
``kernels`` library. When pre-built binaries are not available for the current
platform (unsupported GPU/arch, no network, ...), we transparently fall back to
a source build / local installation of Flash Attention (the ``FLASH_ATTN3_PATH``
override or the installed ``flash-attn`` package), so the dependency stays
optional.

Inspired by https://github.com/OpenDriveLab/AgiBot-World/pull/159
"""

import logging
import os
import sys
from functools import lru_cache

logger = logging.getLogger(__name__)

# Kernel served by the `kernels` library that mirrors the Flash Attention 3 API.
_FLASH_ATTN_KERNEL = "kernels-community/flash-attn3"
_FLASH_ATTN_KERNEL_VERSION = 1


@lru_cache(maxsize=None)
def load_flash_attn():
"""Return a dict of Flash Attention symbols, or an empty dict if unavailable.

The returned dict exposes the callable the attention backend relies on:
``flash_attn_varlen_func``.

Resolution order:
1. Pre-built FA3 binaries via the ``kernels`` library (no compilation).
2. A source build / local installation of Flash Attention
(the original behaviour, including the ``FLASH_ATTN3_PATH`` override).
"""
symbols = _load_from_kernels()
if symbols is not None:
return symbols

symbols = _load_from_source()
if symbols is not None:
return symbols

return {}


def _load_from_kernels():
try:
from kernels import get_kernel

module = get_kernel(_FLASH_ATTN_KERNEL, version=_FLASH_ATTN_KERNEL_VERSION)
# The flash-attn3 kernel exposes its callables at the top level of the
# module (unlike flash-attn2, which nests them under
# `flash_attention_interface`).
logger.info(
"Loaded pre-built Flash Attention 3 binaries via `kernels` (%s).",
_FLASH_ATTN_KERNEL,
)
return {
"flash_attn_varlen_func": module.flash_attn_varlen_func,
}
except Exception as exc: # noqa: BLE001 - any failure should trigger the source fallback
logger.info(
"Pre-built Flash Attention via `kernels` unavailable (%s); "
"falling back to a source build.",
exc,
)
return None


def _load_from_source():
try:
# Check for a Flash Attention 3 installation path first.
flash_attn3_path = os.getenv("FLASH_ATTN3_PATH")
if flash_attn3_path:
logger.info("Using Flash Attention 3 from: %s", flash_attn3_path)
sys.path.insert(0, flash_attn3_path)
from flash_attn_interface import flash_attn_varlen_func
else:
from flash_attn.flash_attn_interface import flash_attn_varlen_func

return {
"flash_attn_varlen_func": flash_attn_varlen_func,
}
except ImportError:
return None