diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml new file mode 100644 index 00000000000..8c529583c72 --- /dev/null +++ b/.github/workflows/_build.yml @@ -0,0 +1,225 @@ +name: ~Build wheel template + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "The C++11 ABI to use for the build" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +jobs: + build-wheel: + runs-on: ${{ inputs.runs-on }} + name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) + steps: + - name: Checkout + uses: actions/checkout@v5 + with: + ref: ${{ inputs.release-version }} + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ inputs.cuda-version }} + if: ${{ inputs.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.29 + id: cuda-toolkit + with: + cuda: ${{ inputs.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} + method: "network" + sub-packages: '["nvcc"]' + + - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} + run: | + pip install --upgrade pip + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 + # We want to figure out the CUDA version to download pytorch + # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # This code is ugly, maybe there's a better way to do this. + export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \ + print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ + ) + # detect if we're on ARM + if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then + PLAT=linux_aarch64 + else + PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64 + fi + echo "PLAT=$PLAT" >> $GITHUB_ENV + if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then + # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904 + pip install jinja2 + TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl + TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl + pip install --no-cache-dir --pre "${TRITON_URL}" + pip install --no-cache-dir --pre "${TORCH_URL}" + else + pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + + - name: Restore build cache + uses: actions/cache/restore@v4 + with: + path: build.tar + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + restore-keys: | + build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}- + + - name: Unpack build cache + run: | + echo ::group::Adjust timestamps + sudo find / -exec touch -t 197001010000 {} + || true + echo ::endgroup:: + + if [ -f build.tar ]; then + find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} + + tar -xpvf build.tar -C . + else + echo "No build.tar found, skipping" + fi + + ls -al ./ + ls -al build/ || true + ls -al csrc/ || true + + - name: Build wheel + id: build_wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==75.8.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM + + export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) + export NVCC_THREADS=2 + export FLASH_ATTENTION_FORCE_BUILD="TRUE" + export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + + # 5h timeout since GH allows max 6h and we want some buffer + EXIT_CODE=0 + timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? + + if [ $EXIT_CODE -eq 0 ]; then + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + fi + + # Store exit code in GitHub env for later steps + echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" + + # Do not fail the job if timeout killed the build + exit $EXIT_CODE + + - name: Log build logs after timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + run: | + ls -al ./ + tar -cvf build.tar . --atime-preserve=replace + + - name: Save build cache timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + uses: actions/cache/save@v4 + with: + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + path: build.tar + + - name: Log Built Wheels + run: | + ls dist + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ inputs.release-version }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + if: inputs.upload-to-release + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000000..25ea5e86b75 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,47 @@ +name: Build wheels + +on: + workflow_dispatch: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + default: ubuntu-22.04 + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "Enable torch flag C++11 ABI (TRUE/FALSE)" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + +jobs: + build-wheels: + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ inputs.runs-on }} + python-version: ${{ inputs.python-version }} + cuda-version: ${{ inputs.cuda-version }} + torch-version: ${{ inputs.torch-version }} + cxx11_abi: ${{ inputs.cxx11_abi }} + upload-to-release: ${{ inputs.upload-to-release }} + release-version: ${{ inputs.release-version }} diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 00000000000..bc304a5641a --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,33 @@ +name: Lint + +on: + pull_request: + paths: + - 'flash_attn/cute/flash_bwd_sm90.py' + - 'flash_attn/cute/flash_bwd_preprocess.py' + - 'flash_attn/cute/flash_bwd_postprocess.py' + - 'flash_attn/cute/softmax.py' + - '.pre-commit-config.yaml' + push: + branches: + - main + paths: + - 'flash_attn/cute/flash_bwd_sm90.py' + - 'flash_attn/cute/flash_bwd_preprocess.py' + - 'flash_attn/cute/flash_bwd_postprocess.py' + - 'flash_attn/cute/softmax.py' + - '.pre-commit-config.yaml' + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Run pre-commit + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6205ebf4b69..47f374ade99 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,16 +13,16 @@ on: - v* jobs: - setup_release: name: Create Release runs-on: ubuntu-latest + outputs: + release-version: ${{ steps.extract_branch.outputs.branch }} steps: - name: Get the tag version id: extract_branch run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} shell: bash - - name: Create Release id: create_release uses: actions/create-release@v1 @@ -35,161 +35,50 @@ jobs: build_wheels: name: Build Wheel needs: setup_release - runs-on: ${{ matrix.os }} - strategy: fail-fast: false matrix: - # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the - # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-22.04] - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1'] - cuda-version: ['12.9.0'] - # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. - # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. - # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) - # when building without C++11 ABI and using it on nvcr images. - cxx11_abi: ['FALSE', 'TRUE'] - exclude: - # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.4.0' - python-version: '3.13' - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Set CUDA and PyTorch versions - run: | - echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV - echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV - echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - - - name: Free up disk space - if: ${{ runner.os == 'Linux' }} - # https://github.com/easimon/maximize-build-space/blob/master/action.yml - # https://github.com/easimon/maximize-build-space/tree/test-report - run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf /opt/hostedtoolcache/CodeQL - - - name: Set up swap space - if: runner.os == 'Linux' - uses: pierotofy/set-swap-space@v1.0 - with: - swap-size-gb: 10 - - - name: Install CUDA ${{ matrix.cuda-version }} - if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.25 - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda-version }} - linux-local-args: '["--toolkit"]' - # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 - # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} - method: 'network' - sub-packages: '["nvcc"]' - - - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} - run: | - pip install --upgrade pip - # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error - # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable - pip install typing-extensions==4.12.2 - # We want to figure out the CUDA version to download pytorch - # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-22.04, ubuntu-22.04-arm] + python-version: ["3.10", "3.11", "3.12", "3.13"] + torch-version: ["2.5.1", "2.6.0", "2.7.1", "2.8.0", "2.9.1"] + cuda-version: ["12.9.1"] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. + cxx11_abi: ["FALSE", "TRUE"] + include: + - torch-version: "2.9.1" + cuda-version: "13.0.2" + python-version: "3.14" + - torch-version: "2.10.0.dev20251108" + cuda-version: "13.0.2" + exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # This code is ugly, maybe there's a better way to do this. - export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \ - print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ - ) - if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} - # Can't use --no-deps because we need cudnn etc. - # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 - pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - else - pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} - fi - nvcc --version - python --version - python -c "import torch; print('PyTorch:', torch.__version__)" - python -c "import torch; print('CUDA:', torch.version.cuda)" - python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" - shell: - bash - - - name: Build wheel - run: | - # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 - # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 - # However this still fails so I'm using a newer version of setuptools - pip install setuptools==75.8.0 - pip install ninja packaging wheel - export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH - export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - # Limit MAX_JOBS otherwise the github runner goes OOM - # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM - MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) NVCC_THREADS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist - tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") - ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - - - name: Log Built Wheels - run: | - ls dist - - - name: Get the tag version - id: extract_branch - run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} - - - name: Get Release with tag - id: get_current_release - uses: joutvhu/get-release@v1 - with: - tag_name: ${{ steps.extract_branch.outputs.branch }} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Upload Release Asset - id: upload_release_asset - uses: actions/upload-release-asset@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.get_current_release.outputs.upload_url }} - asset_path: ./dist/${{env.wheel_name}} - asset_name: ${{env.wheel_name}} - asset_content_type: application/* + # Pytorch < 2.5 does not support Python 3.13 + - torch-version: "2.4.0" + python-version: "3.13" + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + cuda-version: ${{ matrix.cuda-version }} + torch-version: ${{ matrix.torch-version }} + cxx11_abi: ${{ matrix.cxx11_abi }} + release-version: ${{ needs.setup_release.outputs.release-version }} + upload-to-release: true publish_package: name: Publish package needs: [build_wheels] - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 with: - python-version: '3.10' - + python-version: "3.10" - name: Install dependencies run: | pip install ninja packaging wheel twine @@ -197,13 +86,11 @@ jobs: pip install setuptools==75.8.0 # We don't want to download anything CUDA-related here pip install torch --index-url https://download.pytorch.org/whl/cpu - - name: Build core package env: FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE" run: | python setup.py sdist --dist-dir=dist - - name: Deploy env: TWINE_USERNAME: "__token__" diff --git a/.gitignore b/.gitignore index 1f1f8028863..dc508654045 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.ncu-rep .DS_store +.vscode # Byte-compiled / optimized / DLL files __pycache__/ @@ -26,6 +27,10 @@ var/ # IDE-related .idea/ +.vscode/ # Dev venv + +# compile-time generated file +flash_attn_config.py \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..6118dfa2283 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.13 + hooks: + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix] + files: ^flash_attn/cute/.*\.py$ + exclude: &cute_exclude | + (?x)^flash_attn/cute/( + flash_bwd| + flash_fwd| + flash_fwd_sm100| + interface| + )\.py$ + - id: ruff-format + files: ^flash_attn/cute/.*\.py$ + exclude: *cute_exclude diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py new file mode 100644 index 00000000000..cb6bc44eae2 --- /dev/null +++ b/benchmarks/benchmark_attn.py @@ -0,0 +1,417 @@ +from collections import namedtuple +from functools import partial +import math +import os +from typing import NamedTuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +try: + import cudnn +except ImportError: + cudnn = None +# cudnn = None + +Timing = NamedTuple('timing', [('mean', float)]) + + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.cute.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler + +try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +except ImportError: + flash_attn_func = None + flash_attn_varlen_func = None +from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python +from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python +try: + from flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 +except ImportError: + flash_attn_func_v3 = None + flash_attn_varlen_func_v3 = None + +if torch.cuda.get_device_capability()[0] != 9: + flash_attn_func_v3 = None +# flash_attn_func_v3 = None + +flash_attn_func = None + +from triton.testing import do_bench + +def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): + # # Warmup + # for _ in range(5): + # func(*args, **kwargs) + # time.sleep(1) + # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] + # s = torch.cuda.Stream() + # s.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(s): + # for _ in range(2): + # out = func(*args, **kwargs) + # torch.cuda.current_stream().wait_stream(s) + # graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(graph): + # out = func(*args, **kwargs) + # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) + # # return time_f[1].mean + # return time_f[1] + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (None, None): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device='cuda') + col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def cudnn_spda_setup(q, k, v, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu = q, k, v + o_gpu = torch.empty((b, nheads, seqlen_q, headdim_v), dtype=q.dtype, device=q.device) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + + o, stats = graph.sdpa( + name="sdpa", + q=q, + k=k, + v=v, + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + ) + + o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + stats: stats_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) + assert g.shape == (b, nheads, seqlen_q, headdim_v) + assert o.shape == (b, nheads, seqlen_q, headdim_v) + assert lse.shape == (b, nheads, seqlen_q, 1) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g + dq_gpu = torch.empty_like(q_gpu) + dk_gpu = torch.empty_like(k_gpu) + dv_gpu = torch.empty_like(v_gpu) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + o = graph.tensor_like(o_gpu.detach()) + g = graph.tensor_like(g_gpu.detach()) + stats = graph.tensor_like(lse.detach()) + + dq, dk, dv = graph.sdpa_backward( + name="sdpa_backward", + q=q, + k=k, + v=v, + o=o, + dO=g, + stats=stats, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + use_deterministic_algorithm=False, + ) + + dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) + dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) + dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + g: g_gpu, + stats: lse, + dq: dq_gpu, + dk: dk_gpu, + dv: dv_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return dq_gpu, dk_gpu, dv_gpu + + return run + + +torch.manual_seed(0) +repeats = 10 +dropout_p = 0.0 +causal = False +dtype = torch.bfloat16 +# dtype = torch.float8_e4m3fn +dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype +device = 'cuda' +verbose = True +varlen = False +has_backward = True +page_size = None +# page_size = 128 +softcap = 0.0 +V_colmajor = False +deterministic = False +batch_size = 2 +# seqlen = 2048 +seqlen = 8192 +# seqlen = 4096 +# seqlen = 2047 +dim = 2048 +# headdim = 128 +# headdim = 64 +headdim = 256 +# for headdim in [64, 128, 256]: +# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +# bs_seqlen_vals = [(32, 1024), (16, 2048), (8, 4096), (4, 8192), (2, 16384), (1, 32768)] +# bs_seqlen_vals = [(32, 512), (16, 1024)] +# bs_seqlen_vals = [(2, 64 * 132)] +bs_seqlen_vals = [(4, 8192)] +# bs_seqlen_vals = [(1, 16 * 1024)] +time_f = {} +time_b = {} + +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192]: +# for headdim in [64, 96, 128, 192, 256]: +# for headdim in [64, 96, 128]: +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192, 256]: +for headdim in [128]: + # nheads = dim // headdim + nheads = 32 if headdim <= 64 else 16 if headdim <= 192 else 8 + # nheads = 128 + # headdim = 64 + # batch_size = 64 + # seqlen = 512 + # nheads = 8 + # headdim = 128 + nheads_kv = nheads + # nheads_kv = nheads // 8 + # nheads_kv = 1 + # headdim_v = headdim + headdim_v = 128 if headdim == 192 else headdim + # headdim_v = 512 + has_qv = headdim == 64 and headdim_v == 512 + # has_qv = False + # sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + sinks = None + + for batch_size, seqlen in bs_seqlen_vals: + num_splits = 0 + # window_size = (-1, -1) + window_size = (None, None) + window_size_fa = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + pack_gqa = None + # seqlen_q = 64 + seqlen_q = seqlen + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) + q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_(has_backward) + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) + # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) + if varlen: + q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] + cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q + cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None + # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:256] + # seqlen_q = 256 + # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:384] + # seqlen_q = 384 + if page_size is not None: + assert seqlen % page_size == 0 + k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + + # for causal in [False, True]: + for causal in [True]: + print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn: + cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if has_backward and headdim == headdim_v: + cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + if not varlen: + m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + else: + m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean + if has_backward: + time.sleep(1) + if not varlen: + _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav2') + else: + _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav2') + time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean + # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) + + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn: + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') + time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean + if has_backward: + time.sleep(1) + m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') + time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean + # pytorch_profiler(cudnn_spda, backward=False) + # pytorch_profiler(cudnn_spda_bwd, backward=False) + time.sleep(1) + if flash_attn_func_v3 is not None: + if not varlen: + # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) + else: + m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) + time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean + if flash_attn_func_python is not None: + if not varlen: + m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') + else: + m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward: + time.sleep(1) + if not varlen: + _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3') + else: + _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav3') + time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean + time.sleep(1) + # if not varlen: + # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True) + # else: + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) + # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward: + _, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav2 python') + + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') + if cudnn is not None: + print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') + if flash_attn_func_v3 is not None: + print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') + + if flash_attn_func_python is not None: + print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS') diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index 6c4797c83e0..c97581c6581 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -17,12 +17,6 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func -try: - from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax -except ImportError: - scaled_upper_triang_masked_softmax = None - - def attention_pytorch(qkv, dropout_p=0.0, causal=True): """ Arguments: @@ -52,27 +46,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): return output.to(dtype=qkv.dtype) -def attention_megatron(qkv): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, head_dim) - Output: - output: (batch_size, seqlen, nheads, head_dim) - """ - batch_size, seqlen, _, nheads, d = qkv.shape - q, k, v = qkv.unbind(dim=2) - q = rearrange(q, 'b t h d -> (b h) t d') - k = rearrange(k, 'b s h d -> (b h) d s') - softmax_scale = 1.0 / math.sqrt(d) - # Preallocate attn_weights for `baddbmm` - scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) - scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), - '(b h) t s -> b h t s', h=nheads) - attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0) - output = torch.einsum('bhts,bshd->bthd', attention, v) - return output.to(dtype=qkv.dtype) - - torch.manual_seed(0) repeats = 30 batch_size = 8 @@ -130,9 +103,6 @@ def attention_megatron(qkv): # benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG') # # pytorch_profiler(attention, q, k, v, 1.0, backward=True) -# if scaled_upper_triang_masked_softmax is not None: -# benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention') - # from src.ops.fftconv import fftconv_func # dim = nheads * headdim diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 341ae4b2139..9624ba0c334 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -54,7 +54,7 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + # Adding is faster than masked_fill_ scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1) attention_drop = F.dropout(attention, dropout_p) @@ -88,53 +88,65 @@ def time_fwd_bwd(func, *args, **kwargs): speed_f = {} speed_b = {} speed_f_b = {} + for causal in causal_vals: for headdim in headdim_vals: for batch_size, seqlen in bs_seqlen_vals: config = (causal, headdim, batch_size, seqlen) nheads = dim // headdim - qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) - f, b = time_fwd_bwd( - flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False - ) - time_f[config, "Flash2"] = f - time_b[config, "Flash2"] = b - - try: - qkv = qkv.detach().requires_grad_(True) + + # FlashAttention 2 + if "Flash2" in methods: + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) f, b = time_fwd_bwd( - attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False + flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, + repeats=repeats, verbose=False ) - except: # Skip if OOM - f, b = float('nan'), float('nan') - time_f[config, "Pytorch"] = f - time_b[config, "Pytorch"] = b - - if attention_triton is not None: - q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] - # Try both values of sequence_parallel and pick the faster one + time_f[config, "Flash2"] = f + time_b[config, "Flash2"] = b + + # PyTorch baseline + if "Pytorch" in methods: + try: + # fresh tensor avoids grad-history reuse issues + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) + f, b = time_fwd_bwd( + attention_pytorch, qkv, dropout_p, causal=causal, + repeats=repeats, verbose=False + ) + except Exception: + f, b = float('nan'), float('nan') + time_f[config, "Pytorch"] = f + time_b[config, "Pytorch"] = b + + # Triton + if "Triton" in methods and attention_triton is not None: + q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, + device=device, dtype=dtype, requires_grad=True) for _ in range(3)] + # Try both values of sequence_parallel and pick the faster backward try: f, b = time_fwd_bwd( attention_triton, q, k, v, causal, headdim**(-0.5), False, repeats=repeats, verbose=False ) - except: + except Exception: f, b = float('nan'), float('inf') try: _, b0 = time_fwd_bwd( attention_triton, q, k, v, causal, headdim**(-0.5), True, repeats=repeats, verbose=False ) - except: + except Exception: b0 = float('inf') time_f[config, "Triton"] = f time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan') - if xops is not None: - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] + # xFormers CUTLASS + if "xformers.c" in methods and xops is not None: + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) for _ in range(3)] f, b = time_fwd_bwd( xops.memory_efficient_attention, q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None, @@ -143,9 +155,10 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "xformers.c"] = f time_b[config, "xformers.c"] = b - if xops is not None: - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] + # xFormers Flash + if "xformers.f" in methods and xops is not None: + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) for _ in range(3)] f, b = time_fwd_bwd( xops.memory_efficient_attention, q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None, @@ -154,8 +167,11 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "xformers.f"] = f time_b[config, "xformers.f"] = b + # Report print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") for method in methods: + if (config, method) not in time_f or (config, method) not in time_b: + continue time_f_b[config, method] = time_f[config, method] + time_b[config, method] speed_f[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), @@ -175,6 +191,5 @@ def time_fwd_bwd(func, *args, **kwargs): f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" ) - # with open('flash2_attn_time.plk', 'wb') as fp: -# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) +# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file diff --git a/benchmarks/cute/benchmark_block_sparsity.py b/benchmarks/cute/benchmark_block_sparsity.py new file mode 100644 index 00000000000..74f220e8795 --- /dev/null +++ b/benchmarks/cute/benchmark_block_sparsity.py @@ -0,0 +1,363 @@ +""" +Comparative benchmark: CuTe DSL vs Native PyTorch block sparsity computation. +""" + +import torch +from dataclasses import dataclass +from typing import Callable, Optional, List +from tabulate import tabulate +from tqdm import tqdm +import itertools + +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.testing import benchmark as cute_benchmark +import cutlass.cute as cute +from flash_attn.cute.compute_block_sparsity import BlockSparsityKernel +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.mask_definitions import ( + get_mask_pair, + random_doc_id_tensor, + flex_document_mask, + cute_document_mask, +) + +from torch.nn.attention.flex_attention import create_block_mask +from triton.testing import do_bench + +# Configure torch.compile cache to prevent memory buildup +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + batch_size: int + num_heads: int + seqlen_q: int + seqlen_k: int + mask_name: str + tile_m: int = 128 + tile_n: int = 128 + use_fast_sampling: bool = False + aux_tensors_cute: Optional[list] = None + + +@dataclass(frozen=True) +class BenchmarkResult: + """Result of a single benchmark run.""" + + config: BenchmarkConfig + cute_time_ms: Optional[float] + pytorch_time_ms: Optional[float] + error_message: Optional[str] = None + + +def benchmark_pytorch_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark PyTorch block mask creation (compiled). + Returns: creation_time_ms + """ + device = "cuda" + + try: + cbm = torch.compile(create_block_mask) + + def run_benchmark(): + return cbm( + mask_fn, + config.batch_size, + config.num_heads, + config.seqlen_q, + config.seqlen_k, + device=device, + ) + + creation_time_ms = do_bench(run_benchmark, warmup=10, rep=100) + + return creation_time_ms + + except Exception as e: + print(f"PyTorch benchmark failed ({config.mask_name}): {e}") + import traceback + traceback.print_exc() + return None + + +def benchmark_cute_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark CuTe block sparsity kernel. + Returns: creation_time_ms + """ + device = "cuda" + + try: + num_m_blocks = (config.seqlen_q + config.tile_m - 1) // config.tile_m + num_n_blocks = (config.seqlen_k + config.tile_n - 1) // config.tile_n + + mask_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + full_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + + # Convert to CuTe tensors + mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + + blocksparse_tensors = BlockSparseTensors( + mask_block_cnt=mask_cnt_cute, + mask_block_idx=mask_idx_cute, + full_block_cnt=full_cnt_cute, + full_block_idx=full_idx_cute, + ) + + # Create kernel + use_aux = config.aux_tensors_cute is not None and len(config.aux_tensors_cute) > 0 + kernel = BlockSparsityKernel( + mask_mod=mask_fn, + tile_mn=(config.tile_m, config.tile_n), + compute_full_blocks=True, + use_aux_tensors=use_aux, + use_fast_sampling=config.use_fast_sampling, + ) + + # Compile kernel + compiled_kernel = cute.compile( + kernel, + blocksparse_tensors, + config.seqlen_q, + config.seqlen_k, + config.aux_tensors_cute, + ) + + def generate_tensors(): + from cutlass.cute.testing import JitArguments + + return JitArguments( + blocksparse_tensors, config.seqlen_q, config.seqlen_k, config.aux_tensors_cute + ) + + creation_time_us = cute_benchmark( + compiled_kernel, + workspace_generator=generate_tensors, + warmup_iterations=10, + iterations=100, + ) + + torch.cuda.synchronize(device) + creation_time_ms = creation_time_us / 1000.0 + + return creation_time_ms + + except Exception as e: + print(f"CuTe benchmark failed: {e}") + return None + + +def run_benchmark( + config: BenchmarkConfig, + pytorch_mask_fn: Callable, + cute_mask_fn: Callable, +) -> BenchmarkResult: + """Run benchmarks for both implementations.""" + + print( + f"Benchmarking {config.mask_name} - B={config.batch_size}, H={config.num_heads}, " + f"M={config.seqlen_q}, N={config.seqlen_k}" + ) + + # Benchmark PyTorch + pytorch_time = benchmark_pytorch_block_sparsity(config, pytorch_mask_fn) + + # Benchmark CuTe + cute_time = benchmark_cute_block_sparsity(config, cute_mask_fn) + + return BenchmarkResult( + config=config, + cute_time_ms=cute_time, + pytorch_time_ms=pytorch_time, + ) + + +def generate_configs( + batch_sizes: List[int], + num_heads: List[int], + seqlens: List[int], + mask_names: List[str], +) -> List[BenchmarkConfig]: + """Generate all benchmark configurations.""" + configs = [] + for B, H, S, mask_name in itertools.product(batch_sizes, num_heads, seqlens, mask_names): + configs.append( + BenchmarkConfig( + batch_size=B, + num_heads=H, + seqlen_q=S, + seqlen_k=S, + mask_name=mask_name, + ) + ) + return configs + + +def print_results(results: List[BenchmarkResult]): + successful_results = [ + r for r in results if r.cute_time_ms is not None and r.pytorch_time_ms is not None + ] + + if not successful_results: + print("No successful benchmark results to display") + return + + headers = ["B", "H", "M", "N", "Mask Type", "CuTe Time (ms)", "PyTorch Time (ms)", "Speedup"] + + rows = [] + for result in successful_results: + speedup = result.pytorch_time_ms / result.cute_time_ms if result.cute_time_ms > 0 else 0 + + rows.append( + [ + result.config.batch_size, + result.config.num_heads, + result.config.seqlen_q, + result.config.seqlen_k, + result.config.mask_name, + f"{result.cute_time_ms:.4f}", + f"{result.pytorch_time_ms:.4f}", + f"{speedup:.2f}x", + ] + ) + + # Sort by batch, head, seqlen, then mask type + rows.sort(key=lambda x: (x[0], x[1], x[2], x[4])) + + print("\n" + "=" * 100) + print("CuTe DSL vs PyTorch Block Sparsity Benchmark Results") + print("=" * 100) + print(tabulate(rows, headers=headers, tablefmt="github")) + print("=" * 100) + + +def main(): + """Run the comparative benchmark.""" + + # Configuration + batch_sizes = [1, 4, 8] + num_heads = [8, 16] + seqlens = [1024, 2048, 4096, 8192] + mask_names = [ + "causal", + "sliding_window", + "prefix_lm", + "dilated_sliding_window", + "document", + ] + + device = "cuda" + max_seqlen = max(seqlens) + max_batch = max(batch_sizes) + max_heads = max(num_heads) + + # Create document IDs using the helper from mask_definitions + doc_ids = random_doc_id_tensor(max_heads, max_batch, max_seqlen, device=device) + doc_ids_cute = from_dlpack(doc_ids.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + + # Generate base configurations + base_configs = generate_configs(batch_sizes, num_heads, seqlens, mask_names) + + # Update configs with aux tensors for document masking + configs = [] + for config in base_configs: + if config.mask_name == "document": + # Add aux tensors for document masking + configs.append( + BenchmarkConfig( + batch_size=config.batch_size, + num_heads=config.num_heads, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + mask_name=config.mask_name, + tile_m=config.tile_m, + tile_n=config.tile_n, + use_fast_sampling=False, + aux_tensors_cute=[doc_ids_cute], + ) + ) + else: + configs.append(config) + + # Run benchmarks + results = [] + print(f"Running {len(configs)} benchmark configurations...") + for config in tqdm(configs, desc="Benchmarking"): + try: + # Get mask pair from mask_definitions + mask_kwargs = {} + if config.mask_name == "sliding_window": + mask_kwargs["window_size"] = 128 # Default window size + + cute_mask_fn, pytorch_mask_fn = get_mask_pair( + config.mask_name, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + **mask_kwargs, + ) + + # For document masking, create wrapper that captures doc_ids + if config.mask_name == "document": + # PyTorch wrapper + def pytorch_mask_fn(b, h, q, kv): + return flex_document_mask(b, h, q, kv, doc_ids) + # CuTe wrapper - reuse cute_document_mask with aux_tensors + cute_mask_fn = cute_document_mask + + result = run_benchmark(config, pytorch_mask_fn, cute_mask_fn) + results.append(result) + + except Exception as e: + print(f"Failed to run config {config}: {e}") + results.append( + BenchmarkResult( + config=config, + cute_time_ms=None, + pytorch_time_ms=None, + error_message=str(e), + ) + ) + finally: + torch.cuda.empty_cache() + torch._dynamo.reset() + + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/cute/benchmark_mask_mod.py b/benchmarks/cute/benchmark_mask_mod.py new file mode 100644 index 00000000000..348d2ee485d --- /dev/null +++ b/benchmarks/cute/benchmark_mask_mod.py @@ -0,0 +1,686 @@ +""" +FlashAttention benchmarking script with Flex Attention-style +mask mod support and varlen sequences. +""" + +from dataclasses import dataclass +import math +from typing import Any, Dict, Optional, Tuple + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import numpy as np +import torch + +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 +from flash_attn.cute.mask_definitions import ( + get_mask_pair, + random_doc_id_tensor, +) +from flash_attn.cute.block_sparsity import ( + compute_block_sparsity, + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, +) + + +@dataclass +class BenchmarkConfig: + """Benchmark configuration""" + + # Model parameters + headdim: int + headdim_v: int + nheads: int + nheads_kv: int + dtype: torch.dtype + + # Sequence parameters + batch_size: int = 2 + seqlen_q: int = 8192 + seqlen_k: int = 8192 + + # Varlen parameters + use_varlen: bool = False + min_seqlen_q: Optional[int] = None # If None, use seqlen_q // 2 + max_seqlen_q: Optional[int] = None # If None, use seqlen_q + min_seqlen_k: Optional[int] = None # If None, use seqlen_k // 2 + max_seqlen_k: Optional[int] = None # If None, use seqlen_k + + # Mask parameters + use_mask_mod: bool = True + mask_mod_name: str = "causal" + has_aux_tensors: bool = mask_mod_name == "document" + + # Sliding window parameter (used when mask_mod_name == "sliding_window") + window_size: int = 128 + + # Attention parameters + causal: bool = False + is_local: bool = False + window_left: Optional[int] = 128 # For base Flash Attention local + window_right: Optional[int] = 0 # For base Flash Attention local + softcap: Optional[float] = None + use_learnable_sink: bool = False + + # Kernel configuration + tile_m: int = 128 + tile_n: int = 128 + num_stages: int = 2 + num_threads: int = 384 + intra_wg_overlap: bool = True + mma_pv_is_rs: bool = True + + # Benchmark parameters + warmup_iters: int = 10 + benchmark_iters: int = 25 + verbose: bool = False + seed: int = 42 + + +class FlashAttentionBenchmark: + def __init__(self, config: BenchmarkConfig): + self.config = config + + torch.manual_seed(config.seed) + np.random.seed(config.seed) + + # Verify SM90 compute capability + compute_capability = torch.cuda.get_device_capability() + assert compute_capability >= (9, 0), ( + f"Requires SM90+, got SM{compute_capability[0]}{compute_capability[1]}" + ) + # causal overrides use_mask_mod + if config.causal: + config.use_mask_mod = False + + if config.use_mask_mod: + self.mask_mod_cute, self.mask_mod_flex = get_mask_pair( + config.mask_mod_name, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + window_size=config.window_size, + ) + else: + self.mask_mod_cute = None + self.mask_mod_flex = None + + self._validate_config() + + def _validate_config(self): + config = self.config + + assert config.headdim <= 256, "headdim must be <= 256" + assert config.headdim_v <= 256, "headdim_v must be <= 256" + assert config.nheads % config.nheads_kv == 0, "nheads must be divisible by nheads_kv" + + alignment = 16 // config.dtype.itemsize + assert config.headdim % alignment == 0, f"headdim must be divisible by {alignment}" + assert config.headdim_v % alignment == 0, f"headdim_v must be divisible by {alignment}" + + # Validate is_local configuration + if config.is_local: + assert config.window_left is not None or config.window_right is not None, ( + "When is_local=True, at least one of window_left or window_right must be set" + ) + assert not config.use_mask_mod, ( + "Cannot use both is_local and use_mask_mod simultaneously" + ) + assert not config.causal, "Cannot use both is_local and causal simultaneously" + + # Validate mask_mod configuration + if config.use_mask_mod and config.mask_mod_name == "sliding_window": + assert config.window_size > 0, ( + "window_size must be positive when using sliding_window mask" + ) + + def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tuple[torch.Tensor, int]: + """Generate random sequence lengths and compute cumulative lengths.""" + seqlens = torch.randint( + min_len, max_len + 1, (self.config.batch_size,), dtype=torch.int32, device="cuda" + ) + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(seqlens, dtype=torch.int32, dim=0), + ] + ) + + total_tokens = cu_seqlens[-1].item() + return cu_seqlens, total_tokens + + def _create_tensors(self) -> Dict[str, torch.Tensor]: + config = self.config + device = "cuda" + + if config.use_varlen: + # Set defaults for varlen range + min_q = config.min_seqlen_q if config.min_seqlen_q is not None else config.seqlen_q // 2 + max_q = config.max_seqlen_q if config.max_seqlen_q is not None else config.seqlen_q + min_k = config.min_seqlen_k if config.min_seqlen_k is not None else config.seqlen_k // 2 + max_k = config.max_seqlen_k if config.max_seqlen_k is not None else config.seqlen_k + + # Generate cu_seqlens + cu_seqlens_q, total_q = self._generate_varlen_seqlens(min_q, max_q) + cu_seqlens_k, total_k = self._generate_varlen_seqlens(min_k, max_k) + + # Varlen shape: (total_tokens, nheads, headdim) + q = torch.randn( + total_q, config.nheads, config.headdim, dtype=config.dtype, device=device + ) + k = torch.randn( + total_k, config.nheads_kv, config.headdim, dtype=config.dtype, device=device + ) + v = torch.randn( + total_k, config.nheads_kv, config.headdim_v, dtype=config.dtype, device=device + ) + out = torch.empty( + total_q, config.nheads, config.headdim_v, dtype=config.dtype, device=device + ) + lse = torch.empty(config.nheads, total_q, dtype=torch.float32, device=device) + + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + "cu_seqlens_q": cu_seqlens_q.contiguous(), + "cu_seqlens_k": cu_seqlens_k.contiguous(), + } + + if config.verbose: + print(f"Varlen: total_q={total_q}, total_k={total_k}") + print(f"Q seqlens: {cu_seqlens_q[1:] - cu_seqlens_q[:-1]}") + print(f"K seqlens: {cu_seqlens_k[1:] - cu_seqlens_k[:-1]}") + else: + # Standard shape: (batch, seqlen, nheads, headdim) + q = torch.randn( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim, + dtype=config.dtype, + device=device, + ) + k = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim, + dtype=config.dtype, + device=device, + ) + v = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + out = torch.empty( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + lse = torch.empty( + config.batch_size, + config.nheads, + config.seqlen_q, + dtype=torch.float32, + device=device, + ) + + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + } + + if config.use_learnable_sink: + learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device) + + tensors["learnable_sink"] = learnable_sink.contiguous() + + # Compute block sparsity when using mask_mod + if config.use_mask_mod: + if config.mask_mod_name == "document": + doc_id = random_doc_id_tensor( + config.batch_size, config.nheads, config.seqlen_q, device=device + ) + tensors["aux_tensors"] = [doc_id.contiguous()] + full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( + config=self.config, + mask_mod_flex=self.mask_mod_flex, + device=device, + cu_seqlens_q=tensors.get("cu_seqlens_q"), + cu_seqlens_k=tensors.get("cu_seqlens_k"), + aux_tensors=tensors.get("aux_tensors"), + ) + + if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): + tensors["block_sparse_tensors"] = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt.contiguous(), + mask_block_idx=mask_idx.contiguous(), + full_block_cnt=full_cnt.contiguous(), + full_block_idx=full_idx.contiguous(), + ) + + if config.verbose: + total_full = full_cnt.sum().item() + total_partial = mask_cnt.sum().item() + + if config.use_varlen: + # Compute max possible blocks across all sequences + max_blocks = 0 + for i in range(config.batch_size): + seq_len_q = ( + tensors["cu_seqlens_q"][i + 1] - tensors["cu_seqlens_q"][i] + ).item() + seq_len_k = ( + tensors["cu_seqlens_k"][i + 1] - tensors["cu_seqlens_k"][i] + ).item() + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + max_blocks += n_blocks_q * n_blocks_k * config.nheads + else: + n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m + max_blocks = n_blocks_k * n_blocks_q * config.nheads * config.batch_size + + skipped = max_blocks - total_full - total_partial + print( + f"Block stats: Full={total_full}, Partial={total_partial}, " + f"Skipped={skipped}/{max_blocks}" + ) + + return tensors + + def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]: + config = self.config + + dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + } + cute_dtype = dtype_map[config.dtype] + + qhead_per_kvhead = config.nheads // config.nheads_kv + kernel = FlashAttentionForwardSm90( + cute_dtype, + config.headdim, + config.headdim_v, + qhead_per_kvhead, + is_causal=config.causal, + is_local=config.is_local, + pack_gqa=False, + tile_m=config.tile_m, + tile_n=config.tile_n, + num_stages=config.num_stages, + num_threads=config.num_threads, + intra_wg_overlap=config.intra_wg_overlap, + mma_pv_is_rs=config.mma_pv_is_rs, + mask_mod=self.mask_mod_cute, + Q_in_regs=False, + has_aux_tensors=config.has_aux_tensors, + ) + + softmax_scale = 1.0 / math.sqrt(config.headdim) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Convert tensors to cute + q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["q"].ndim - 1 + ) + k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["k"].ndim - 1 + ) + v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["v"].ndim - 1 + ) + out_cute = from_dlpack(tensors["out"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["out"].ndim - 1 + ) + lse_cute = from_dlpack(tensors["lse"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=tensors["lse"].ndim - 1 + ) + + # Varlen tensors + cu_seqlens_q_cute = ( + from_dlpack(tensors["cu_seqlens_q"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_q" in tensors + else None + ) + cu_seqlens_k_cute = ( + from_dlpack(tensors["cu_seqlens_k"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_k" in tensors + else None + ) + learnable_sink_cute = ( + from_dlpack(tensors["learnable_sink"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "learnable_sink" in tensors + else None + ) + + blocksparse_tensors_cute = ( + to_cute_block_sparse_tensors(tensors["block_sparse_tensors"]) + if "block_sparse_tensors" in tensors + else None + ) + + if "aux_tensors" in tensors: + aux_tensors_cute = [] + for i in range(len(tensors["aux_tensors"])): + buf = from_dlpack(tensors["aux_tensors"][i].detach(), assumed_align=4) + aux_tensors_cute.append(buf.mark_layout_dynamic(leading_dim=2)) + + else: + aux_tensors_cute = None + + # Window parameters for is_local + window_left_cute = ( + cutlass.Int32(config.window_left) if config.window_left is not None else None + ) + window_right_cute = ( + cutlass.Int32(config.window_right) if config.window_right is not None else None + ) + + compiled = cute.compile( + kernel, + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, # seqused_q + None, # seqused_k + None, # page_table + window_left_cute, + window_right_cute, + learnable_sink_cute, + blocksparse_tensors_cute, + aux_tensors_cute, + # None, + ) + + args = ( + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, + None, + None, + window_left_cute, + window_right_cute, + learnable_sink_cute, + blocksparse_tensors_cute, + aux_tensors_cute, + # None, + ) + + return compiled, args + + def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: + config = self.config + + # Estimate sparsity for known mask patterns + if config.is_local: + # Local attention with window_left and window_right + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 # +1 for current position + sparsity_ratio = min(1.0, total_window / config.seqlen_k) + elif config.use_mask_mod: + if config.mask_mod_name in ["identity", "identity_partial"]: + sparsity_ratio = 1.0 + elif config.mask_mod_name in ["causal", "block_causal"]: + sparsity_ratio = 0.5 + elif config.mask_mod_name == "sliding_window": + # Use configured window size + sparsity_ratio = min(1.0, config.window_size / config.seqlen_k) + elif config.mask_mod_name == "block_diagonal": + block_size = 64 + num_blocks = (config.seqlen_k + block_size - 1) // block_size + sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0 + elif config.mask_mod_name == "document": + vals = tensors["aux_tensors"][0] + val_mask = torch.ones_like(vals, dtype=torch.bool) + val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1] + total = torch.where(val_mask, vals.square(), 0).sum() + sparsity_ratio = total / (config.seqlen_q * config.seqlen_k) + else: + sparsity_ratio = 1.0 + elif config.causal: + sparsity_ratio = 0.5 + else: + sparsity_ratio = 1.0 + + if config.use_varlen: + # Compute FLOPs per sequence and sum + total_flops = 0 + cu_q = tensors["cu_seqlens_q"] + cu_k = tensors["cu_seqlens_k"] + for i in range(config.batch_size): + seq_len_q = (cu_q[i + 1] - cu_q[i]).item() + seq_len_k = (cu_k[i + 1] - cu_k[i]).item() + + # Adjust sparsity for local attention in varlen case + if config.is_local: + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 + seq_sparsity = min(1.0, total_window / seq_len_k) + elif config.use_mask_mod and config.mask_mod_name == "sliding_window": + seq_sparsity = min(1.0, config.window_size / seq_len_k) + else: + seq_sparsity = sparsity_ratio + + num_cells = int(seq_len_q * seq_len_k * seq_sparsity) + + if config.headdim == config.headdim_v: + flops_this_seq = 4 * config.nheads * num_cells * config.headdim + else: + flops_this_seq = ( + 2 * config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + total_flops += flops_this_seq + return total_flops + else: + num_cells = int(config.seqlen_q * config.seqlen_k * sparsity_ratio) + if config.headdim == config.headdim_v: + flops_per_batch = 4 * config.nheads * num_cells * config.headdim + else: + flops_per_batch = ( + 2 * config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + return flops_per_batch * config.batch_size + + def benchmark(self) -> Dict[str, Any]: + config = self.config + + tensors = self._create_tensors() + compiled_kernel, args = self._compile_kernel(tensors) + + # Warmup + for _ in range(config.warmup_iters): + compiled_kernel(*args) + torch.cuda.synchronize() + + # Benchmark + times = [] + for _ in range(config.benchmark_iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + compiled_kernel(*args) + end.record() + torch.cuda.synchronize() + + times.append(start.elapsed_time(end)) + + times_tensor = torch.tensor(times) + mean_time = times_tensor.mean().item() + std_time = times_tensor.std().item() if len(times) > 1 else 0.0 + + total_flops = self._calculate_flops(tensors) + tflops = total_flops / (mean_time * 1e-3) / 1e12 + + # Bandwidth calculation + bytes_per_element = config.dtype.itemsize + if config.use_varlen: + total_q = tensors["q"].shape[0] + total_k = tensors["k"].shape[0] + memory_accessed = ( + total_q * config.nheads * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim_v * bytes_per_element + + total_q * config.nheads * config.headdim_v * bytes_per_element + ) + else: + memory_accessed = ( + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim + * bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim + * bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim_v + * bytes_per_element + + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim_v + * bytes_per_element + ) + bandwidth_gbps = memory_accessed / (mean_time * 1e-3) / 1e9 + + results = { + "mean_time_ms": mean_time, + "std_time_ms": std_time, + "tflops": tflops, + "bandwidth_gbps": bandwidth_gbps, + } + + if config.verbose: + self._print_results(results) + + return results + + def _print_results(self, results: Dict[str, Any]): + config = self.config + + # Basic configuration + if config.use_varlen: + print( + f"Shape: B={config.batch_size} (varlen), HD={config.headdim}, " + f"NH={config.nheads}, NKV={config.nheads_kv}" + ) + else: + print( + f"Shape: B={config.batch_size}, Q={config.seqlen_q}, K={config.seqlen_k}, " + f"HD={config.headdim}, NH={config.nheads}, NKV={config.nheads_kv}" + ) + + # Attention pattern + attn_info = [] + if config.causal: + attn_info.append("causal") + if config.is_local: + window_info = f"local(L={config.window_left},R={config.window_right})" + attn_info.append(window_info) + if config.use_mask_mod: + if config.mask_mod_name == "sliding_window": + attn_info.append(f"mask_mod={config.mask_mod_name}(w={config.window_size})") + else: + attn_info.append(f"mask_mod={config.mask_mod_name}") + if config.use_varlen: + attn_info.append("varlen") + if attn_info: + print(f"Attention: {', '.join(attn_info)}") + + # Performance metrics + print(f"Time: {results['mean_time_ms']:.3f} ± {results['std_time_ms']:.3f} ms") + print(f"Throughput: {results['tflops']:.2f} TFLOPS") + print(f"Bandwidth: {results['bandwidth_gbps']:.1f} GB/s") + + +if __name__ == "__main__": + B = 2 + config = BenchmarkConfig( + headdim=128, + headdim_v=128, + nheads=16, + nheads_kv=16, + dtype=torch.bfloat16, + batch_size=B, + # batch_size=1, + seqlen_q=8192, + # seqlen_q=128, + seqlen_k=8192, + # seqlen_k=192, + use_varlen=False, + use_mask_mod=False, + mask_mod_name="causal", + window_size=128, # Configurable window size for mask_mod + use_learnable_sink=False, + causal=True, + is_local=False, + verbose=True, + ) + + # Example 2: Base Flash Attention Local + # config = BenchmarkConfig( + # headdim=64, + # headdim_v=64, + # nheads=64, + # nheads_kv=8, + # dtype=torch.bfloat16, + # batch_size=2, + # seqlen_q=8192, + # seqlen_k=8192, + # use_varlen=False, + # use_mask_mod=False, + # causal=False, + # is_local=True, + # window_left=128, # Left window size for base local attention + # window_right=0, # Right window size for base local attention + # verbose=True, + # ) + + benchmark = FlashAttentionBenchmark(config) + results = benchmark.benchmark() diff --git a/csrc/composable_kernel b/csrc/composable_kernel index 663992e99b4..e8709c24f40 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 663992e99b412991eab554b0deb89bb916d40161 +Subproject commit e8709c24f403173ad21a2da907d1347957e324fb diff --git a/csrc/cutlass b/csrc/cutlass index dc4817921ed..b1d6e2c9b33 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit dc4817921edda44a549197ff3a9dcf5df0636e7b +Subproject commit b1d6e2c9b334dfa811e4183dfbd02419249e4b52 diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index dd7a5c3f9b4..c0c0e42176c 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -515,7 +515,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. @@ -1340,7 +1340,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, - /*p_ptr=*/nullptr, + /*p_d=*/nullptr, softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index 1f016a4a4e6..bb879453680 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -220,7 +220,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num if (is_causal) { window_size_right = 0; } bool is_dropout = p_dropout > 0.0; +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, @@ -399,4 +403,4 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num } return { dq, dk, dv, softmax_d }; -} \ No newline at end of file +} diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index 68e28355189..4d7d5bd655e 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -272,7 +272,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num if (seqlen_k > 0) { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif ck_tile::stream_config stream_config{stream}; auto traits = diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 3e4422efecd..07cfa9a8f90 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -469,7 +469,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si } if (max_seqlen_k > 0) { +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif ck_tile::stream_config stream_config{stream}; if (paged_KV) diff --git a/csrc/ft_attention/README.md b/csrc/ft_attention/README.md deleted file mode 100644 index 97feb78cc1c..00000000000 --- a/csrc/ft_attention/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Attention kernel from FasterTransformer - -This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from -FasterTransformer v5.2.1 for benchmarking purpose. - -```sh -cd csrc/ft_attention && pip install . -``` - -As of 2023-09-17, this extension is no longer used in the FlashAttention repo. -FlashAttention now has implemented -[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attention_interface.py) -with all the features of this `ft_attention` kernel (and more). - diff --git a/csrc/ft_attention/cuda_bf16_fallbacks.cuh b/csrc/ft_attention/cuda_bf16_fallbacks.cuh deleted file mode 100644 index f5641f61609..00000000000 --- a/csrc/ft_attention/cuda_bf16_fallbacks.cuh +++ /dev/null @@ -1,257 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include - -namespace fastertransformer { - -#ifdef ENABLE_BF16 -inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; -#else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; -#endif -} - -inline __device__ __nv_bfloat162 float22bf162(const float2 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); -#else - return __float22bfloat162_rn(val); -#endif -} - -inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; -#else - return __bfloat162bfloat162(val); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#else - return __hadd2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); -#else - return __hadd(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); -#else - return __hsub2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); -#else - return __hsub(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#else - return __hmul2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); -#else - return __hmul(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); -#else - return __hfma2(x, y, z); -#endif -} - -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); -#else - return __hfma(x, y, z); -#endif -} - -inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x);; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); -#else - return h2exp(x); -#endif -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; - -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; t.x = x; t.y = y; return t; -} - -#endif - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); -#else - return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); -#else - return a * b * c + d; -#endif -} - -#endif // ENABLE_BF16 - -} // namespace fastertransformer diff --git a/csrc/ft_attention/cuda_bf16_wrapper.h b/csrc/ft_attention/cuda_bf16_wrapper.h deleted file mode 100644 index efb6e798730..00000000000 --- a/csrc/ft_attention/cuda_bf16_wrapper.h +++ /dev/null @@ -1,23 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef ENABLE_BF16 -#include -#endif diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu deleted file mode 100644 index 13306f76868..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.cu +++ /dev/null @@ -1,149 +0,0 @@ -// Adapted from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include -#include -#include - -#include "decoder_masked_multihead_attention_template.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - auto kernel = mmha::masked_multihead_attention_kernel; \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \ - kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); - if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); - } - else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); - } - else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#undef MMHA_LAUNCH_KERNEL - -template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - switch (params.hidden_size_per_head) { - case 32: - mmha_launch_kernel(params, stream); - break; - case 48: - mmha_launch_kernel(params, stream); - break; - case 64: - mmha_launch_kernel(params, stream); - break; - case 80: - mmha_launch_kernel(params, stream); - break; - case 96: - mmha_launch_kernel(params, stream); - break; - case 128: - mmha_launch_kernel(params, stream); - break; - case 160: - mmha_launch_kernel(params, stream); - break; - case 192: - mmha_launch_kernel(params, stream); - break; - case 224: - mmha_launch_kernel(params, stream); - break; - case 256: - mmha_launch_kernel(params, stream); - break; - default: - assert(false); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.h b/csrc/ft_attention/decoder_masked_multihead_attention.h deleted file mode 100644 index 3c79f88b856..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.h +++ /dev/null @@ -1,192 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The structure of parameters for the masked multihead attention kernel. -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. - -template -struct Multihead_attention_params_base { - - // The output buffer. Dimensions B x D. - T* out = nullptr; - - // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; - // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; - // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; - - // The cache for the Ks. The size must be at least B x L x D. - T* k_cache = nullptr; - // The cache for the Vs. The size must be at least B x L x D. - T* v_cache = nullptr; - // The indirections to use for cache when beam sampling. - const int* cache_indir = nullptr; - - // Stride to handle the case when KQV is a single buffer - int stride_q = 0; - int stride_k = 0; - int stride_v = 0; - - // The batch size. - int batch_size = 0; - // The beam width - int beam_width = 0; - // The sequence length. - int memory_max_len = 0; - // The number of heads (H). - int num_heads = 0; - int num_heads_kv = 0; - int num_heads_q_kv_ratio = 0; - // The hidden dimension per head (Dh). - int hidden_size_per_head = 0; - // The per-head latent space reserved for rotary embeddings. - int rotary_embedding_dim = 0; - bool neox_rotary_style = false; - float rotary_base = 0.0f; - // The maximum length of input sentences. - int max_input_length = 0; - // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? - int timestep = 0; - // The current timestep of each sentences (support different timestep for different sentences) - - // The 1.f / sqrt(Dh). Computed on the host. - float inv_sqrt_dh = 0.0f; - - // Used when we have some input context like gpt - const int* total_padding_tokens = nullptr; - - const bool* masked_tokens = nullptr; - const int* prefix_prompt_lengths = nullptr; - int max_prefix_prompt_length = 0; - - const T* relative_attention_bias = nullptr; - int relative_attention_bias_stride = 0; - // The slope per head of linear position bias to attention score (H). - const T* linear_bias_slopes = nullptr; - - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; - - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; - int int8_mode = 0; - - const T *rotary_cos = nullptr; - const T *rotary_sin = nullptr; - - const int *nnz_head_idx = nullptr; - int nnz_heads = 0; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - // will need it here till if constexpr in c++17 - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -using Masked_multihead_attention_params = Multihead_attention_params; - -template -using Cross_multihead_attention_params = Multihead_attention_params; - -template -struct outputCrossAttentionParam { - // max decoder output length - int max_decoder_seq_len = 0; - T* cross_attention_out = nullptr; - bool is_return_cross_attentions = false; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp deleted file mode 100644 index 2ae1b2425b8..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp +++ /dev/null @@ -1,1619 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include -#include -#include - -// #define MMHA_USE_HMMA_FOR_REDUCTION - -// Below are knobs to extend FP32 accumulation for higher FP16 accuracy - -// Does not seem to affect the accuracy that much -#define MMHA_USE_FP32_ACUM_FOR_FMA - -// Seems to slightly improve the accuracy -#define MMHA_USE_FP32_ACUM_FOR_OUT - -#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) - // Does not seem to improve the accuracy - //#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#endif - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. -// -// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use -// 64, 128 and 256 threads per block. -// -// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to -// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The -// cache buffer helps with memory accesses and contains keys with bias. -// -// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and -// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The -// values for x are chosen to create chunks of 16 bytes. -// -// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs -// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At -// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an -// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. -// -// After that loop, a parallel softmax is computed across the different Q * K^T values stored in -// shared memory. -// -// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many -// timesteps are computed by loop iteration. As with the keys, the values are read from a cache -// except for the current timestep. The layout of the cache buffer for the values is much simpler -// as it is [B, H, L, Dh]. -// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_ { -}; - -template<> -struct Qk_vec_ { - using Type = float; -}; -template<> -struct Qk_vec_ { - using Type = float2; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint2; -}; -template<> -struct Qk_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct Qk_vec_<__nv_bfloat16, 32> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 64> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 128> { - using Type = bf16_4_t; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 256> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_ { -}; - -template<> -struct K_vec_ { - using Type = float; -}; -template<> -struct K_vec_ { - using Type = float2; -}; -template<> -struct K_vec_ { - using Type = float4; -}; -template<> -struct K_vec_ { - using Type = uint32_t; -}; -template<> -struct K_vec_ { - using Type = uint2; -}; -template<> -struct K_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct K_vec_<__nv_bfloat16, 4> { - using Type = __nv_bfloat162; -}; -template<> -struct K_vec_<__nv_bfloat16, 2> { - using Type = bf16_4_t; -}; -template<> -struct K_vec_<__nv_bfloat16, 1> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_ { -}; - -template<> -struct V_vec_ { - using Type = float; -}; -template<> -struct V_vec_ { - using Type = float2; -}; -template<> -struct V_vec_ { - using Type = float4; -}; -template<> -struct V_vec_ { - using Type = uint32_t; -}; -template<> -struct V_vec_ { - using Type = uint2; -}; -template<> -struct V_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_<__nv_bfloat16, 2> { - using Type = __nv_bfloat162; -}; -template<> -struct V_vec_<__nv_bfloat16, 4> { - using Type = bf16_4_t; -}; -template<> -struct V_vec_<__nv_bfloat16, 8> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct Qk_vec_acum_fp32_ { -}; - -template<> -struct Qk_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float4; -}; -// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_acum_fp32_ { -}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ { -}; - -template<> -struct V_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = K_vec; -#endif - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -// Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(float u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(Float8_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float float_from_int8(int8_t u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 float_from_int8(int16_t u) -{ - union { - int16_t int16; - int8_t int8[2]; - }; - int16 = u; - return make_float2(int8[0], int8[1]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 float_from_int8(int32_t u) -{ - union { - int32_t int32; - int8_t int8[4]; - }; - int32 = u; - return make_float4(int8[0], int8[1], int8[2], int8[3]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// clang-format off -inline __device__ Float8_ float_from_int8(int64_t u) -{ - union { - int64_t int64; - int16_t int16[4]; - }; - int64 = u; - return Float8_ {float_from_int8(int16[0]), - float_from_int8(int16[1]), - float_from_int8(int16[2]), - float_from_int8(int16[3])}; -} -// clang-format on - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int8_t cast_to_int8(float val) -{ - union { - int8_t int8[2]; - int16_t int16; - }; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); - return int8[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int32_t cast_to_int8(float4 val) -{ - union { - int8_t int8[4]; - int32_t int32; - }; - int8[0] = cast_to_int8(val.x); - int8[1] = cast_to_int8(val.y); - int8[2] = cast_to_int8(val.z); - int8[3] = cast_to_int8(val.w); - return int32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int64_t cast_to_int8(Float8_ val) -{ - union { - int8_t int8[8]; - int64_t int64; - }; - int8[0] = cast_to_int8(val.x.x); - int8[1] = cast_to_int8(val.x.y); - int8[2] = cast_to_int8(val.y.x); - int8[3] = cast_to_int8(val.y.y); - int8[4] = cast_to_int8(val.z.x); - int8[5] = cast_to_int8(val.z.y); - int8[6] = cast_to_int8(val.w.x); - int8[7] = cast_to_int8(val.w.y); - return int64; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T div_up(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline size_t smem_size_in_bytes(const Multihead_attention_params& params, - int threads_per_value, - int threads_per_block) -{ - // The amount of shared memory needed to store the Q*K^T values in float. - const int max_timesteps = min(params.timestep, params.memory_max_len); - size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - - // The extra memory needed if we are not using floats for the final logits. - size_t logits_sz = 0; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TDOD - logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) : - div_up(max_timesteps + 1, 4) * 4 * sizeof(T); - } -#endif - - // The total size needed during softmax. - size_t softmax_sz = qk_sz + logits_sz; - - // The number of partial rows to reduce in the final reduction. - int rows_per_red = threads_per_block / threads_per_value; - // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; - - size_t transpose_rotary_size = 0; - if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); - } - - // The max. - return max(max(softmax_sz, red_sz), transpose_rotary_size); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ constexpr uint32_t shfl_mask(int threads) -{ - return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the inputs. Supported types: float and half. - typename T, - // The hidden dimension per head. - int Dh, - int Dh_MAX, - // The number of threads per key. - int THREADS_PER_KEY, - // The number of threads per value. - int THREADS_PER_VALUE, - // The number of threads in a threadblock. - int THREADS_PER_BLOCK, - bool DO_CROSS_ATTENTION> -__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) -{ - - // Make sure the hidden dimension per head is a multiple of the number of threads per key. - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - // Make sure the hidden dimension per head is a multiple of the number of threads per value. - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - // The size of a warp. - constexpr int WARP_SIZE = 32; - // The number of warps in a threadblock. - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // Use smem_size_in_bytes (above) to determine the amount of shared memory. - extern __shared__ char smem_[]; - - // The shared memory for the Q*K^T values and partial logits in softmax. - float* qk_smem = reinterpret_cast(smem_); - - // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. - char* logits_smem_ = smem_; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TODO - change to tlength - const int max_timesteps = min(params.timestep, params.memory_max_len); - logits_smem_ += - (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - } - T* logits_smem = reinterpret_cast(logits_smem_); -#else - float* logits_smem = reinterpret_cast(logits_smem_); -#endif - - // The shared memory to do the final reduction for the output values. Reuse qk_smem. - T* out_smem = reinterpret_cast(smem_); - - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - - // Use alignment for safely casting the shared buffers as Qk_vec. - // Shared memory to store Q inputs. - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - // This is one of the reasons we should have a separate kernel for cross attention - __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - // The number of elements per vector. - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // We will use block wide reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - // The number of vectors per warp. - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread - // owns x elements, we have to decompose the linear index into chunks of x values and the posi- - // tion of the thread in that chunk. - - // The number of elements in a chunk of 16B (that's the x in the above formula). - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - // The number of K vectors in 16B. - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - // The batch/beam idx - const int bi = blockIdx.y; - if (params.finished != nullptr && params.finished[bi] == true) { - return; - } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; - // The head. - // const int hi = blockIdx.x; - const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x]; - const int hi_kv = hi / params.num_heads_q_kv_ratio; - // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; - const int bhi_kv = bi * params.num_heads_kv + hi_kv; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv; - // The thread in the block. - const int tidx = threadIdx.x; - - const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); - - // While doing the product Q*K^T for the different keys we track the max. - float qk_max = -FLT_MAX; - - float qk = 0.0F; - - int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh; - int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh; - int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh; - - const size_t bi_seq_len_offset = bi * params.memory_max_len; - - // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : - (params.length_per_sample == nullptr) ? - params.timestep : - params.length_per_sample[bi] + params.max_prefix_prompt_length; - const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len; - - // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. - const bool is_masked = tidx >= QK_VECS_PER_WARP; - - // The offset in the Q and K buffer also accounts for the batch. - int q_offset = q_base_offset + tidx * QK_VEC_SIZE; - int k_offset = k_base_offset + tidx * QK_VEC_SIZE; - // The offset in the bias buffer. - int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE; - - const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; - const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; - - // Trigger the loads from the Q and K buffers. - Qk_vec q; - zero(q); - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_out[0]; - const auto q_quant = - *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); - - convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); - } - else { - q = *reinterpret_cast(¶ms.q[q_offset]); - } - } - - Qk_vec k; - zero(k); - if (DO_CROSS_ATTENTION) { - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? - *reinterpret_cast(¶ms.k_cache[offset]) : - k; - } - else { - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = *reinterpret_cast(¶ms.k[k_offset]); - } - } - } - - // Trigger the loads from the Q and K bias buffers. - Qk_vec q_bias; - zero(q_bias); - q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - *reinterpret_cast(¶ms.q_bias[q_bias_offset]) : - q_bias; - - Qk_vec k_bias; - zero(k_bias); - if (handle_kv) { - k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? - *reinterpret_cast(¶ms.k_bias[k_bias_offset]) : - k_bias; - } - - // Computes the Q/K values with bias. - q = add(q, q_bias); - if (handle_kv) { - k = add(k, k_bias); - } - if (do_ia3 && !is_masked) { - k = mul( - k, - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])); - } - - // Padded len - const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; - if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { - if (handle_kv) { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - else { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - } - else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; - - T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; - - const int half_rotary_dim = params.rotary_embedding_dim / 2; - const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; - const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts - - assert(half_rotary_dim % QK_VEC_SIZE == 0); - - if (do_rotary) { - *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; - - if (handle_kv) { - *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; - } - } - - __syncthreads(); - - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; - constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; - if (do_rotary) { - mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - - if (handle_kv) { - mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - - mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - } - else { - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - } - - __syncthreads(); - - if (do_rotary) { - q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); - if (handle_kv) { - k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); - } - } - - __syncthreads(); - } - - if (!is_masked) { - // Store the Q values to shared memory. - *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - - // Store Dh values of k_bias into smem, since will need to add later - // if params.timestep == 0 - if (DO_CROSS_ATTENTION && params.timestep == 0) { - *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - } - - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength_circ * QK_ELTS_IN_16B + ci; - - if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) { - // Trigger the stores to global memory. - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = k; - } - } - - // Compute \sum_i Q[i] * K^T[i] for the current timestep. -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; -#else - using Qk_vec_acum = Qk_vec; -#endif - qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. - if (tidx == 0) { - // Normalize qk. - qk *= params.inv_sqrt_dh; - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + (tlength - padd_len) * params.relative_attention_bias_stride - + (tlength - padd_len)]); - } - // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. - - qk_max = qk; - qk_smem[tlength - first_step] = qk; - // qk_smem[params.timestep] = qk; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The type of queries and keys for the math in the Q*K^T product. - using K_vec = typename K_vec_::Type; - // The number of elements per vector. - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - // The number of elements per thread. - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - // The number of vectors per thread. - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - // The position the first key loaded by each thread from the cache buffer (for this B * H). - int ko = tidx / THREADS_PER_KEY; - // The position of the thread in the chunk of keys. - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); - - // Load the Q values from shared memory. The values are reused during the loop on K. - K_vec q_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - - K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; - if (DO_CROSS_ATTENTION && params.timestep == 0) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - } - - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; - - // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). - // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; - - // prefix prompt length if has - const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; - - // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - const bool has_beams = params.cache_indir != nullptr; - const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; - - for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // The keys loaded from the key cache. - K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.memory_max_len + ti_circ; - // if( ti < params.timestep ) { - const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); - if (ti < tlength) { - if (!within_bounds) { - k[ii] = k_vec_zero; - } - else { - if (has_beams) { - const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; - k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); - } - else { - k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); - } - } - // add bias and update k_cache - if (DO_CROSS_ATTENTION && params.timestep == 0) { - k[ii] = add(k[ii], k_bias_vec[ii]); - - if (do_ia3) { - k[ii] = mul( - k[ii], - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki - + ii * THREADS_PER_KEY * K_VEC_SIZE])); - } - - if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { - *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; - } - } - } - } - - // Perform the dot product and normalize qk. - // - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - - // Store the product to shared memory. There's one qk value per timestep. Update the max. - // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]); - } - if (params.linear_bias_slopes != nullptr) { - // Apply the linear position bias: (ki - qi) * slope[hi]. - // The padding token locates between the input context and the generated tokens. - // We need to remove the number of padding tokens in the distance computation. - // ti : 0 1 2 3 4 5 6 7 8 9(tlength) - // token: i i i i p p p o o o where i=input, p=pad, o=output. - // e.g. ti = 2, dist = (9 - 3) - 2 = 4. - int max_context_length = params.max_prefix_prompt_length + params.max_input_length; - float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; - - qk += mul(params.linear_bias_slopes[hi], dist); - } - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = qk; - } - } - -// Perform the final reduction to compute the max inside each warp. -// -// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the -// group so it's not needed to run the reduction inside the group (again). -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - const int warp = tidx / WARP_SIZE; - const int lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Compute the logits and start the sum. - float sum = 0.f; - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); - sum += logit; - qk_smem[ti - first_step] = logit; - } - - // Compute the sum. - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // Normalize the logits. - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - const size_t cross_attention_out_offset = - params.is_return_cross_attentions ? - bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : - 0; - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - float logit = qk_smem[ti - first_step] * inv_sum; - if (params.is_return_cross_attentions) { - params.cross_attention_out[cross_attention_out_offset + ti] = logit; - } - convert_from_float(logits_smem[ti - first_step], logit); - } - - // Put Values part below so we leverage __syncthreads - // from the previous step - - // The number of elements per vector. - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - // A vector of V elements for the current timestep. - using V_vec = typename V_vec_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - - // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; - - // The number of values processed per iteration of the loop. - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - // One group of threads computes the product(s) for the current timestep. - V_vec v_bias; - zero(v_bias); - // if( vo == params.timestep % V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - if (handle_kv) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = *reinterpret_cast(¶ms.v_bias[hi_kv * Dh + vi]); - } - if (DO_CROSS_ATTENTION) { - *reinterpret_cast(&bias_smem[vi]) = v_bias; - } - } - } - } - - // From previous, before values, step - // Also make sure the logits are in shared memory. - __syncthreads(); - - // Values continued -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec; -#endif - // The partial outputs computed by each thread. - V_vec_acum out; - zero(out); - - // Loop over the timesteps to compute the partial outputs. - // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // Fetch offset based on cache_indir when beam sampling - const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; - const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; - // Load the values from the cache. - V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); - if (DO_CROSS_ATTENTION && params.timestep == 0) { - v = add(v, *reinterpret_cast(&bias_smem[vi])); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - *reinterpret_cast(&v_cache[ti * Dh]) = v; - } - // Load the logits from shared memory. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti - first_step]; - out = fma(logit, cast_to_float(v), out); -#else - T logit = logits_smem[ti - first_step]; - - // Update the partial sums. - out = fma(logit, v, out); -#endif - } - } - - // One group of threads computes the product(s) for the current timestep. - // if( vo == params.timestep % V_PER_ITER ) { - if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { - - V_vec v; - if (DO_CROSS_ATTENTION) { - v = *reinterpret_cast(&v_cache[tlength * Dh]); - } - else { - // Trigger the loads from the V buffer. - const auto v_offset = v_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = *reinterpret_cast(¶ms.v[v_offset]); - } - // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); - } - - // Compute the V values with bias. - if (handle_kv) { - v = add(v, v_bias); - - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - - // Store the values with bias back to global memory in the cache for V. - if (hi % params.num_heads_q_kv_ratio == 0) { - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; - } - } - - // Initialize the output value with the current timestep. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - // out = fma(logits_smem[params.timestep], cast_to_float(v), out); - out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); -#else - // out = fma(logits_smem[params.timestep], v, out); - out = fma(logits_smem[tlength - first_step], v, out); -#endif - } - - // Make sure we can start writing to shared memory. - __syncthreads(); - - // Run the final reduction amongst the different groups computing different partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. - if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); -#else - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; -#endif - } - __syncthreads(); - - // The bottom warps update their values. - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); - } - __syncthreads(); - } - } - - // Output the final values. - if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - out = mul(*params.attention_out_scale, out); - *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = - cast_to_int8(out); - } - else { - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); - } -#else - // TODO: support int8_mode? - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; -#endif - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace mmha - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h deleted file mode 100644 index 98875aba9b8..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h +++ /dev/null @@ -1,2017 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include - -using namespace fastertransformer; - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float4_ { - float2 x; - float2 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -struct bf16_4_t { - __nv_bfloat162 x; - __nv_bfloat162 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct bf16_8_t { - __nv_bfloat162 x; - __nv_bfloat162 y; - __nv_bfloat162 z; - __nv_bfloat162 w; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct num_elems; -template<> -struct num_elems { - static constexpr int value = 1; -}; -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -#ifdef ENABLE_BF16 -template<> -struct num_elems<__nv_bfloat162> { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct packed_type; -template -struct packed_type { - using type = T; -}; -template<> -struct packed_type { - using type = int16_t; -}; -template<> -struct packed_type { - using type = int32_t; -}; -template<> -struct packed_type { - using type = int64_t; -}; - -template<> -struct packed_type { - using type = float2; -}; -template<> -struct packed_type { - using type = float4; -}; -template<> -struct packed_type { - using type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, float b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(float2 a, float2 b) -{ - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 add(float4 a, float4 b) -{ - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t add(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t add(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 add(uint2 a, uint2 b) -{ - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 add(uint4 a, uint4 b) -{ - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t float_to_half(float f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? - float zero = 0.f; - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); -#endif - return tmp.u16[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t float2_to_half2(float2 f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float half_to_float(uint16_t h) -{ - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 half2_to_float2(uint32_t v) -{ - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, uint16_t b) -{ - return a + half_to_float(b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float add(float a, __nv_bfloat16 b) -{ - return a + __bfloat162float(b); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(uint32_t a, float2 fb) -{ - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(uint2 a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(uint4 a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t h0_h0(uint16_t a) -{ - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(float a, float b, float c) -{ - return a * b + c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float2 a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float4 a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) -{ - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) -{ - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float2 add(__nv_bfloat162 a, float2 fb) -{ - float2 fa = bf1622float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) -{ - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) -{ - return fma(h0_h0(a), b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) -{ - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) -{ - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) -{ - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) -{ - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(uint16_t a, uint16_t b, float fc) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) -{ - return fma(h0_h0(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) -{ - uint32_t s = h0_h0(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) -{ - uint32_t s = h0_h0(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(bf162bf162(a), b, c); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) -{ - bf16_4_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) -{ - bf16_8_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) -{ - return __bfloat162float(a) * __bfloat162float(b) + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) -{ - return fma(bf162bf162(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ Acc mul(A a, B b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float2 a, float2 b) -{ - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float a, float2 b) -{ - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float4 a, float4 b) -{ - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float a, float4 b) -{ - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(float a, Float8_ b) -{ - Float8_ c; - c.x = make_float2(a * b.x.x, a * b.x.y); - c.y = make_float2(a * b.y.x, a * b.y.y); - c.z = make_float2(a * b.z.x, a * b.z.y); - c.w = make_float2(a * b.w.x, a * b.w.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint2 a, uint2 b) -{ - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - uint2 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint4 a, uint4 b) -{ - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - uint4 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - c.z = mul(s, b.z); - c.w = mul(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, uint16_t b) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, float b) -{ - return half_to_float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint32_t a, uint32_t b) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint2 a, uint2 b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint4 a, uint4 b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -template<> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __hmul(a, b); -#else - return bf16hmul(a, b); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ - float fa = (float)a; - float fb = (float)b; - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, float b) -{ - return __bfloat162float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float v) -{ - return v; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float2 v) -{ - return v.x + v.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float4 v) -{ - return v.x + v.y + v.z + v.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float sum(__nv_bfloat162 v) -{ - float2 vf = bf1622float2(v); - return vf.x + vf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_4_t v) -{ - return sum(v.x) + sum(v.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_8_t v) -{ - return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint16_t v) -{ - return half_to_float(v); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint32_t v) -{ - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint2 v) -{ - uint32_t c = add(v.x, v.y); - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint4 v) -{ -#if 1 - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); -#else - uint32_t c = add(v.x, v.y); - uint32_t d = add(v.z, v.w); - c = add(c, d); -#endif - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float4_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float8_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t& dst) -{ - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void zero(T& dst) -{ - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base) -{ - const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim); - return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)}; -} - -inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) -{ - float2 rot_v; - rot_v.x = coef.x * v.x - coef.y * v.y; - rot_v.y = coef.x * v.y + coef.y * v.x; - return rot_v; -} - -inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) -{ - float2 fv = half2_to_float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return float2_to_half2(rot_fv); -} - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) -{ - float2 fv = bf1622float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); -} -#endif - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin) -{ - // zid is the index of the dimension (0, 2, 4, ..., rotary_dim). - // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2. - return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])}; -} - -// fp16 is special because we use uint16_t for reading the data, for backward compatibility. -template <> -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - // zid is the index of the dimension (0, 2, 4, ..., rotary_dim). - // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2. - return {float(reinterpret_cast(rotary_cos)[zid / 2]), - float(reinterpret_cast(rotary_sin)[zid / 2])}; -} - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - - vec = tmp_3.u32x2; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - tmp_3.u16[4] = tmp_1.u16[2]; - tmp_3.u16[5] = tmp_2.u16[2]; - tmp_3.u16[6] = tmp_1.u16[3]; - tmp_3.u16[7] = tmp_2.u16[3]; - - vec = tmp_3.u32x4; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - __nv_bfloat16 bf16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; -} - -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - __nv_bfloat16 bf16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; - vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; - vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; -} -#endif // ENABLE_BF16 - -template<> -__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.z = smem[transpose_idx + 1]; - vec.y = smem[smem_pitch + transpose_idx]; - vec.w = smem[smem_pitch + transpose_idx + 1]; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} -#endif - -template<> -__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} - -template -__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u32x4 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - tmp_1.u16[2] = tmp_3.u16[4]; - tmp_2.u16[2] = tmp_3.u16[5]; - tmp_1.u16[3] = tmp_3.u16[6]; - tmp_2.u16[3] = tmp_3.u16[7]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u32x2 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u32 = vec; - - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -template<> -__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[transpose_idx + 1] = vec.z; - smem[smem_pitch + transpose_idx] = vec.y; - smem[smem_pitch + transpose_idx + 1] = vec.w; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - - tmp.u32 = vec; - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} -#endif - -template<> -__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -} // namespace mmha diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp deleted file mode 100644 index 886da9729ba..00000000000 --- a/csrc/ft_attention/ft_attention.cpp +++ /dev/null @@ -1,231 +0,0 @@ -#include -#include "ATen/cuda/CUDAContext.h" -#include - - -#include "decoder_masked_multihead_attention.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \ - if (TYPE == at::ScalarType::Half) { \ - using scalar_t = at::Half; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::BFloat16) { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::Float) { \ - using scalar_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \ - } - -template -void masked_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -void cross_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -struct SATypeConverter { - using Type = T; -}; - -template<> -struct SATypeConverter { - using Type = uint16_t; -}; - -template<> -struct SATypeConverter { - using Type = __nv_bfloat16; -}; - -template -void set_params(Masked_multihead_attention_params ¶ms, - const size_t batch_size, - const size_t nheads, - const size_t nheads_kv, - const size_t memory_max_seqlen, - const size_t headdim, - const int timestep, - const int rotary_embedding_dim, - const float rotary_base, - const bool neox_rotary_style, - const int q_batch_stride, - const int k_batch_stride, - const int v_batch_stride, - const int nnz_heads, - T *q_ptr, - T *k_ptr, - T *v_ptr, - T *k_cache_ptr, - T *v_cache_ptr, - int *length_per_sample, - T *rotary_cos, - T *rotary_sin, - T *out_ptr, - int *nnz_head_idx) { - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.q_bias = nullptr; - params.k_bias = nullptr; - params.v_bias = nullptr; - params.k_cache = k_cache_ptr; - params.v_cache = v_cache_ptr; - params.out = out_ptr; - params.cache_indir = nullptr; - params.stride_q = q_batch_stride; - params.stride_k = k_batch_stride; - params.stride_v = v_batch_stride; - params.batch_size = batch_size; - params.beam_width = 1; - params.memory_max_len = memory_max_seqlen; - params.num_heads = nheads; - params.num_heads_kv = nheads_kv; - params.num_heads_q_kv_ratio = nheads / nheads_kv; - params.nnz_heads = nnz_heads; - params.hidden_size_per_head = headdim; - params.rotary_embedding_dim = rotary_embedding_dim; - params.rotary_base = rotary_base; - params.neox_rotary_style = neox_rotary_style; - params.timestep = timestep; - params.inv_sqrt_dh = 1.f / sqrt(float(headdim)); - params.total_padding_tokens = nullptr; - params.masked_tokens = nullptr; - params.prefix_prompt_lengths = nullptr; - params.max_prefix_prompt_length = 0; - params.relative_attention_bias = nullptr; - params.relative_attention_bias_stride = 0; - params.cross_attention_out = nullptr; - params.max_decoder_seq_len = 0; - params.is_return_cross_attentions = false; - params.finished = nullptr; - params.memory_length_per_sample = nullptr; - params.length_per_sample = length_per_sample; - params.rotary_cos = rotary_cos; - params.rotary_sin = rotary_sin; - params.nnz_head_idx = nnz_head_idx; -} - -torch::Tensor single_query_attention(const torch::Tensor q, - const torch::Tensor k, - const torch::Tensor v, - torch::Tensor k_cache, - torch::Tensor v_cache, - std::optional length_per_sample_, - std::optional rotary_cos_, - std::optional rotary_sin_, - std::optional nnz_head_idx_, - const int timestep, - int rotary_embedding_dim = 0, - const float rotary_base = 10000.0f, - const bool neox_rotary_style=true) { - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache); - int batch_size = v_cache.size(0); - int nheads = q.size(1); - int nheads_kv = v_cache.size(1); - int memory_max_seqlen = v_cache.size(2); - int headdim = v_cache.size(3); - auto input_type = q.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - - CHECK_SHAPE(q, batch_size, nheads, headdim); - CHECK_SHAPE(k, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); - // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 - int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; - CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); - TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); - TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); - TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); - CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); - - TORCH_CHECK(q.scalar_type() == input_type); - TORCH_CHECK(k.scalar_type() == input_type); - TORCH_CHECK(v.scalar_type() == input_type); - TORCH_CHECK(k_cache.scalar_type() == input_type); - TORCH_CHECK(v_cache.scalar_type() == input_type); - - if (length_per_sample_.has_value()) { - auto length_per_sample = length_per_sample_.value(); - CHECK_DEVICE(length_per_sample); - CHECK_SHAPE(length_per_sample, batch_size); - CHECK_CONTIGUOUS(length_per_sample); - TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); - } - - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_DEVICE(rotary_cos); - rotary_embedding_dim = rotary_cos.size(-1) * 2; - CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_cos); - TORCH_CHECK(rotary_cos.scalar_type() == input_type); - - TORCH_CHECK(rotary_sin_.has_value()); - auto rotary_sin = rotary_sin_.value(); - CHECK_DEVICE(rotary_sin); - CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_sin); - TORCH_CHECK(rotary_sin.scalar_type() == input_type); - } - - if (nnz_head_idx_.has_value()) { - auto nnz_head_idx = nnz_head_idx_.value(); - CHECK_DEVICE(nnz_head_idx); - int nnz_heads = nnz_head_idx.size(0); - CHECK_SHAPE(nnz_head_idx, nnz_heads); - CHECK_CONTIGUOUS(nnz_head_idx); - TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32); - } - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - torch::Tensor out = torch::empty_like(q); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { - using DataType = typename SATypeConverter::Type; - Masked_multihead_attention_params params; - set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep, - rotary_embedding_dim, rotary_base, neox_rotary_style, - q.stride(0), k.stride(0), v.stride(0), - nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0, - reinterpret_cast(q.data_ptr()), - reinterpret_cast(k.data_ptr()), - reinterpret_cast(v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - length_per_sample_.has_value() - ? length_per_sample_.value().data_ptr() : nullptr, - rotary_cos_.has_value() - ? reinterpret_cast(rotary_cos_.value().data_ptr()) : nullptr, - rotary_sin_.has_value() - ? reinterpret_cast(rotary_sin_.value().data_ptr()) : nullptr, - reinterpret_cast(out.data_ptr()), - nnz_head_idx_.has_value() ? nnz_head_idx_.value().data_ptr() : nullptr - ); - auto stream = at::cuda::getCurrentCUDAStream(); - masked_multihead_attention(params, stream); - }); - return out; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_query_attention", &single_query_attention, "Attention with a single query", - py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), - py::arg("length_per_sample_"), py::arg("rotary_cos_"), - py::arg("rotary_sin_"), py::arg("nnz_head_idx_"), - py::arg("timestep"), py::arg("rotary_embedding_dim")=0, - py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); -} diff --git a/csrc/ft_attention/setup.py b/csrc/ft_attention/setup.py deleted file mode 100644 index fa385ad768c..00000000000 --- a/csrc/ft_attention/setup.py +++ /dev/null @@ -1,153 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -from setuptools import setup, find_packages -import subprocess - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--ft_attention") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("ft_attention is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="ft_attention", - sources=[ - "ft_attention.cpp", - "decoder_masked_multihead_attention.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-DENABLE_BF16"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-DENABLE_BF16", # TODO - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="ft_attention", - version="0.1", - description="Attention for single query from FasterTransformer", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/fused_softmax/fused_softmax.cpp b/csrc/fused_softmax/fused_softmax.cpp deleted file mode 100644 index 2aaed913314..00000000000 --- a/csrc/fused_softmax/fused_softmax.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - - return fwd_cuda(input, mask, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); -} - -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return fwd_cuda(input, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("scaled_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - - m.def("scaled_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); - - m.def("scaled_masked_softmax_get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); - - m.def("scaled_upper_triang_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("scaled_upper_triang_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); -} diff --git a/csrc/fused_softmax/scaled_masked_softmax.h b/csrc/fused_softmax/scaled_masked_softmax.h deleted file mode 100644 index 14b9f6e4242..00000000000 --- a/csrc/fused_softmax/scaled_masked_softmax.h +++ /dev/null @@ -1,528 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - // compute scale value to account for full mask - acc_t scale_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; - } - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] * scale_value[i]/ sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} -} // end of anonymous namespace - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 12: // 4096 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 13: // 8192 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 12: // 4096 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 13: // 8192 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_masked_softmax_cuda.cu deleted file mode 100644 index a08e752699c..00000000000 --- a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,121 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); -} - - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches - ); - ); - return softmax_results; -} - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); - - auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - void* input_grads_ptr = static_cast(input_grads.data_ptr()); - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(input_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads - ); - ); - return input_grads; -} -} -} -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h b/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h deleted file mode 100644 index 21e93fb313a..00000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h +++ /dev/null @@ -1,529 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) - */ -template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } - } - } -} - -template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu deleted file mode 100644 index 79ec30be364..00000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu +++ /dev/null @@ -1,98 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor) -{ - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 8192); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - return softmax_results; -} - - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/csrc/fused_softmax/setup.py b/csrc/fused_softmax/setup.py deleted file mode 100644 index 9c1c6ed76e9..00000000000 --- a/csrc/fused_softmax/setup.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron -# We add the case where seqlen = 4k and seqlen = 8k -import os -import subprocess - -import torch -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -cc_flag = [] -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") - -setup( - name='fused_softmax_lib', - ext_modules=[ - CUDAExtension( - name='fused_softmax_lib', - sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu', 'scaled_upper_triang_masked_softmax_cuda.cu'], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - } - ) - ], - cmdclass={ - 'build_ext': BuildExtension -}) diff --git a/csrc/fused_softmax/type_shim.h b/csrc/fused_softmax/type_shim.h deleted file mode 100644 index 815ec7ec889..00000000000 --- a/csrc/fused_softmax/type_shim.h +++ /dev/null @@ -1,20 +0,0 @@ -#include - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ -switch(TYPE) \ -{ \ -case at::ScalarType::Half: \ - { \ -using scalar_t = at::Half; \ -__VA_ARGS__; \ -break; \ - } \ -case at::ScalarType::BFloat16: \ - { \ -using scalar_t = at::BFloat16; \ -__VA_ARGS__; \ -break; \ - } \ -default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -} diff --git a/csrc/rotary/rotary.cpp b/csrc/rotary/rotary.cpp deleted file mode 100644 index 640eea423ac..00000000000 --- a/csrc/rotary/rotary.cpp +++ /dev/null @@ -1,40 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj); - -void apply_rotary(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - CHECK_DEVICE(x1); CHECK_DEVICE(x2); - CHECK_DEVICE(cos); CHECK_DEVICE(sin); - CHECK_DEVICE(out1); CHECK_DEVICE(out1); - TORCH_CHECK(x1.dtype() == x2.dtype()); - TORCH_CHECK(cos.dtype() == sin.dtype()); - TORCH_CHECK(out1.dtype() == out2.dtype()); - TORCH_CHECK(x1.dtype() == cos.dtype()); - TORCH_CHECK(x1.dtype() == out1.dtype()); - TORCH_CHECK(x1.sizes() == x2.sizes()); - TORCH_CHECK(cos.sizes() == sin.sizes()); - TORCH_CHECK(out1.sizes() == out2.sizes()); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{x1.device()}; - - apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_rotary", &apply_rotary, "Apply rotary embedding"); -} diff --git a/csrc/rotary/rotary_cuda.cu b/csrc/rotary/rotary_cuda.cu deleted file mode 100644 index 2dd0ff3f6e2..00000000000 --- a/csrc/rotary/rotary_cuda.cu +++ /dev/null @@ -1,45 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include -#include - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - auto iter = at::TensorIteratorConfig() - .add_output(out1) - .add_output(out2) - .add_input(x1) - .add_input(x2) - .add_input(cos) - .add_input(sin) - .check_all_same_dtype(false) - .promote_inputs_to_common_dtype(false) - .build(); - - if (!conj) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin); - scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin); - scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } -} \ No newline at end of file diff --git a/csrc/rotary/setup.py b/csrc/rotary/setup.py deleted file mode 100644 index 24d328d9c6a..00000000000 --- a/csrc/rotary/setup.py +++ /dev/null @@ -1,126 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -raise_if_cuda_home_none("rotary_emb") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("rotary_emb is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - 'rotary_emb', [ - 'rotary.cpp', - 'rotary_cuda.cu', - ], - extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], - 'nvcc': append_nvcc_threads([ - '-O3', '--use_fast_math', '--expt-extended-lambda' - ] + cc_flag) - } - ) -) - -setup( - name="rotary_emb", - version="0.1", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/xentropy/README.md b/csrc/xentropy/README.md deleted file mode 100644 index 1bc90fdab77..00000000000 --- a/csrc/xentropy/README.md +++ /dev/null @@ -1,14 +0,0 @@ -This CUDA extension implements optimized cross-entropy loss, adapted from Apex's -[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). -We make it work for bfloat16 and support in-place backward to save memory. - -It has only been tested on A100s. - -```sh -cd csrc/xentropy && pip install . -``` - -As of 2023-09-15, this extension is no longer used in the FlashAttention repo. -We've instead switched to a Triton-based -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py). -See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details. diff --git a/csrc/xentropy/interface.cpp b/csrc/xentropy/interface.cpp deleted file mode 100644 index 41a783fd0fc..00000000000 --- a/csrc/xentropy/interface.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include - -// CUDA forward declarations -std::vector softmax_xentropy_cuda( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes); - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes); - -// C++ interface - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector softmax_xentropy_forward( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes=-1) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - CHECK_INPUT(input); - CHECK_INPUT(labels); - - return softmax_xentropy_cuda(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes=-1) { - CHECK_INPUT(grad_loss); - CHECK_INPUT(logits); - CHECK_INPUT(max_log_sum_exp); - CHECK_INPUT(labels); - - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, - smoothing, inplace, total_classes); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1); -} diff --git a/csrc/xentropy/setup.py b/csrc/xentropy/setup.py deleted file mode 100644 index 5079b4f3847..00000000000 --- a/csrc/xentropy/setup.py +++ /dev/null @@ -1,139 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--xentropy") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("xentropy is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="xentropy_cuda_lib", - sources=[ - "interface.cpp", - "xentropy_kernel.cu" - ], - extra_compile_args={ - "cxx": ["-O3"] + generator_flag, - "nvcc": append_nvcc_threads( - ["-O3"] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="xentropy_cuda_lib", - version="0.1", - description="Cross-entropy loss", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/xentropy/xentropy_kernel.cu b/csrc/xentropy/xentropy_kernel.cu deleted file mode 100644 index 66aab0007ba..00000000000 --- a/csrc/xentropy/xentropy_kernel.cu +++ /dev/null @@ -1,758 +0,0 @@ -// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu -// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory). -/** - * From PyTorch: - * - * Copyright (c) 2016- Facebook, Inc (Adam Paszke) - * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) - * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) - * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) - * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) - * Copyright (c) 2011-2013 NYU (Clement Farabet) - * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) - * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) - * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - * - * From Caffe2: - * - * Copyright (c) 2016-present, Facebook Inc. All rights reserved. - * - * All contributions by Facebook: - * Copyright (c) 2016 Facebook Inc. - * - * All contributions by Google: - * Copyright (c) 2015 Google Inc. - * All rights reserved. - * - * All contributions by Yangqing Jia: - * Copyright (c) 2015 Yangqing Jia - * All rights reserved. - * - * All contributions from Caffe: - * Copyright(c) 2013, 2014, 2015, the respective contributors - * All rights reserved. - * - * All other contributions: - * Copyright(c) 2015, 2016 the respective contributors - * All rights reserved. - * - * Caffe2 uses a copyright model similar to Caffe: each contributor holds - * copyright over their contributions to Caffe2. The project versioning records - * all such contribution and copyright details. If a contributor wants to further - * mark their specific copyright on a particular contribution, they should - * indicate their copyright solely in the commit message of the change when it is - * committed. - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America - * and IDIAP Research Institute nor the names of its contributors may be - * used to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ -#include -#include -#include - -#include -#include - -// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } -// #else -// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ -// switch(TYPE) \ -// { \ -// case at::ScalarType::Float: \ -// { \ -// using scalar_t_##LEVEL = float; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// case at::ScalarType::Half: \ -// { \ -// using scalar_t_##LEVEL = at::Half; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// default: \ -// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -// } -// #endif - -#define ALIGN_BYTES 16 - -using Tensor = at::Tensor; -using TensorList = at::TensorList; -using ScalarType = at::ScalarType; -using at::acc_type; - -template -struct LogSoftMaxForwardEpilogue { - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) - : logsum(max_input + std::log(sum)) {} - - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) - : logsum(max_log_sum_exp) {} - - __device__ __forceinline__ OutT operator()(T input) const { - return static_cast(input - logsum); - } - - const AccumT logsum; -}; - -template -struct LogSoftMaxBackwardEpilogue { - __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) - : sum(sum) {} - - __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { - return static_cast(gradOutput - std::exp(static_cast(output)) * sum); - } - - const AccumT sum; -}; - - - -const int max_threads = 1024; - -inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { - uint64_t block_size = 1; - uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); - while (block_size < (max_block_size/2)) block_size *= 2; - // Launch at least a single warp - the kernel assumes that. - block_size = std::max(block_size, static_cast(32)); - return dim3(block_size); -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -// Regular kernel (fast when dim_size is large; requires inner_size == 1) -//////////////////////////////////////////////////////////////////////////////// - - -template -struct MaxFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { - return ::max(max, (AccumT)v); - } -}; - -template -struct AddFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + v; - } -}; - -template -struct SumExpFloat -{ - __device__ __forceinline__ SumExpFloat(AccumT v) - : max_k(v) {} - - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + std::exp(v - max_k); - } - - const AccumT max_k; -}; - -template class Reduction, typename AccumT> -__device__ __forceinline__ AccumT -blockReduce(AccumT* smem, AccumT val, - const Reduction& r, - AccumT defaultVal) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val; - - __syncthreads(); - - AccumT warpVal = defaultVal; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal = r(warpVal, smem[lane * 32 + i]); - } - __syncwarp(mask); - smem[lane] = warpVal; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal = defaultVal; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal = r(blockVal, smem[i]); - } - smem[0] = blockVal; - } - - // Sync and broadcast - __syncthreads(); - return smem[0]; -} - -template class Reduction1, template class Reduction2, typename AccumT> -__device__ __forceinline__ void -blockReduce(AccumT* smem, - AccumT* reducVal1, - AccumT val1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - AccumT val2, - const Reduction2& r2, - AccumT defaultVal2) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val1; - smem[blockDim.x + threadIdx.x] = val2; - - __syncthreads(); - - AccumT warpVal1 = defaultVal1; - AccumT warpVal2 = defaultVal2; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal1 = r1(warpVal1, smem[lane * 32 + i]); - warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); - } - __syncwarp(mask); - smem[lane] = warpVal1; - smem[lane + blockDim.x] = warpVal2; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal1 = defaultVal1; - AccumT blockVal2 = defaultVal2; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal1 = r1(blockVal1, smem[i]); - blockVal2 = r2(blockVal2, smem[i + blockDim.x]); - } - smem[0] = blockVal1; - smem[blockDim.x] = blockVal2; - } - - // Sync and broadcast - __syncthreads(); - *reducVal1 = smem[0]; - *reducVal2 = smem[blockDim.x]; - __syncthreads(); -} - -template class Reduction, int ILP, typename T, typename AccumT> -__device__ __forceinline__ AccumT -ilpReduce(int shift, - T* data, - int size, - const Reduction& r, - AccumT defaultVal) -{ - typedef typename std::aligned_storage::type LoadT; - AccumT threadVal = defaultVal; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal = r(threadVal, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal = r(threadVal, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) - threadVal = r(threadVal, data[offset]); - - return threadVal; -} - -template class Reduction1, template class Reduction2, int ILP, typename T, typename AccumT> -__device__ __forceinline__ void -ilpReduce(int shift, - T* data, - int size, - AccumT* reducVal1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - const Reduction2& r2, - AccumT defaultVal2) -{ - typedef typename std::aligned_storage::type LoadT; - - AccumT threadVal1 = defaultVal1; - AccumT threadVal2 = defaultVal2; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal1 = r1(threadVal1, v[j]); - threadVal2 = r2(threadVal2, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) { - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - - *reducVal1 = threadVal1; - *reducVal2 = threadVal2; -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyForward( - accscalar_t *losses, - outscalar_t *max_log_sum_exp, - scalar_t *input, - int64_t *labels, - int64_t classes, - const float smoothing, - const int total_classes) -{ - extern __shared__ unsigned char smem[]; - auto sdata = reinterpret_cast(smem); - // forward pointers to batch[blockIdx.x] - // each block handles a sample in the mini-batch - input += blockIdx.x * classes; - //output += blockIdx.x * classes; - const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); - - int64_t label = labels[blockIdx.x]; - - // find the max and sum - accscalar_t threadMax, threadSum, max_k, sum_k; - ilpReduce( - shift, input, classes, - &threadMax, MaxFloat(), - -at::numeric_limits::max(), - &threadSum, AddFloat(), - static_cast(0)); - - blockReduce( - sdata, - &max_k, threadMax, Max(), - -at::numeric_limits::max(), - &sum_k, threadSum, Add(), - static_cast(0)); - - accscalar_t threadExp = ilpReduce(shift, input, classes, SumExpFloat(max_k), static_cast(0)); - accscalar_t sumAll = blockReduce( - sdata, threadExp, Add(), static_cast(0)); - - Epilogue epilogue(max_k, sumAll); - - // calculate per element loss with label smoothing - // reserve max + log_sum_exp for bprop - if (threadIdx.x == 0) { - accscalar_t lse = max_k + std::log(sumAll); - accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast(input[label])) : 0.f; - losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing); - max_log_sum_exp[blockIdx.x] = lse; - } -} - -template -__device__ __forceinline__ void -apply(scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - int last = classes % (ILP * blockDim.x); - - for (; offset < classes - last; offset += blockDim.x * ILP) { - accscalar_t tmpLogits[ILP]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - tmpLogits[j] = static_cast(logits[offset + j * blockDim.x]); - } - -#pragma unroll - for (int j = 0; j < ILP; ++j) - gradInput[offset + j * blockDim.x] = tmpGradOutput * ( - std::exp(tmpLogits[j] - coeff) - static_cast( - (offset + j * blockDim.x == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast((offset == label) ? 1 : 0) * - smooth_positives - smooth_negatives); -} - - -template -__device__ __forceinline__ void -aligned_apply(int shift, - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - logits -= shift; - gradInput -= shift; - classes += shift; - if(threadIdx.x >= shift){ - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - classes -= blockDim.x; - gradInput += blockDim.x; - logits += blockDim.x; - shift -= blockDim.x; - } - - int last = classes % (ILP * blockDim.x); - - typedef typename std::aligned_storage::type LoadT; - // input - scalar_t v[ILP]; - LoadT* value = reinterpret_cast(&v); - // output - scalar_t r[ILP]; - LoadT* result = reinterpret_cast(&r); - - for (; offset * ILP < (classes - last); offset += blockDim.x) { - *value = reinterpret_cast(logits)[offset]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - r[j] = tmpGradOutput * (std::exp( - static_cast(v[j]) - coeff) - - static_cast(((ILP * offset + j - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - reinterpret_cast(gradInput)[offset] = *result; - } - - offset = classes - last + threadIdx.x; - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyBackward( - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - gradInput += blockIdx.x * classes; - logits += blockIdx.x * classes; - - // Do vectorized load/store when input/output have same alignment - const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); - const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); - if (shift == shift_){ - aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - else { - apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - -} - -template class Epilogue> -std::vector host_softmax_xentropy( - const Tensor & input_, - const Tensor & labels_, - const float smoothing, - const int total_classes) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long"); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{input_.device()}; - - auto input = input_.contiguous(); - Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float)); - Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float)); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0"); - - const int64_t dim = 1; - int64_t outer_size = 1; - int64_t dim_size = input.size(dim); - int64_t inner_size = 1; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - for (int64_t i = 0; i < dim; ++i) - outer_size *= input.size(i); - for (int64_t i = dim + 1; i < input.dim(); ++i) - inner_size *= input.size(i); - // This kernel spawns a block per each element in the batch. - // XXX: it assumes that inner_size == 1 - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - using namespace at; - DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy", - using accscalar_t = at::acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyForward - <<>>( - losses.data_ptr(), max_log_sum_exp.data_ptr(), - input.data_ptr(), labels_.data_ptr(), - dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - - std::vector ret = {losses, max_log_sum_exp}; - return ret; -} - -template class Epilogue> -Tensor host_softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits_, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - bool inplace, - const int total_classes) { - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{grad_loss.device()}; - - const int64_t dim = 1; - Tensor gI = inplace ? logits_ : at::empty_like(logits_); - if (grad_loss.numel() == 0) { - return gI; - } - - auto grad = grad_loss.contiguous(); - auto logits = logits_.contiguous(); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - if (grad.dim() == 0) grad = grad.view(1); - - AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0"); - AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples"); - - int64_t outer_size = 1; - int64_t dim_size = logits.size(dim); - int64_t inner_size = 1; - for (int64_t i = 0; i < dim; ++i) - outer_size *= logits.size(i); - for (int64_t i = dim + 1; i < logits.dim(); ++i) - inner_size *= logits.size(i); - // See descriptions of kernels above. - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", - using accscalar_t = acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyBackward - <<>>( - gI.data_ptr(), logits.data_ptr(), - max_log_sum_exp.data_ptr(), - grad.data_ptr(), labels.data_ptr(), - smoothing, dim_size, total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - return gI; -} - -std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){ - return host_softmax_xentropy(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes) { - AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float"); - return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes); -} diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 9ef52f504bb..4a8a7c33f46 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.8.0.post2" +__version__ = "2.8.3" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/flash_attn/cute/.flake8 b/flash_attn/cute/.flake8 new file mode 100644 index 00000000000..bae5b85c002 --- /dev/null +++ b/flash_attn/cute/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 100 +# W503: line break before binary operator +ignore = E731, E741, F841, W503 diff --git a/flash_attn/cute/AUTHORS b/flash_attn/cute/AUTHORS new file mode 100644 index 00000000000..bc3991c676d --- /dev/null +++ b/flash_attn/cute/AUTHORS @@ -0,0 +1,5 @@ +Tri Dao, tri@tridao.me +Jay Shah +Ted Zadouri +Markus Hoehnerbach +Vijay Thakkar \ No newline at end of file diff --git a/flash_attn/cute/LICENSE b/flash_attn/cute/LICENSE new file mode 100644 index 00000000000..5860e4b33f3 --- /dev/null +++ b/flash_attn/cute/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/flash_attn/cute/README.md b/flash_attn/cute/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py new file mode 100644 index 00000000000..fbbfc14050e --- /dev/null +++ b/flash_attn/cute/__init__.py @@ -0,0 +1,21 @@ +"""Flash Attention CUTE (CUDA Template Engine) implementation.""" + +__version__ = "0.1.0" + +import cutlass.cute as cute + +from .interface import ( + flash_attn_func, + flash_attn_varlen_func, +) + +from flash_attn.cute.cute_dsl_utils import cute_compile_patched + +# Patch cute.compile to optionally dump SASS +cute.compile = cute_compile_patched + + +__all__ = [ + "flash_attn_func", + "flash_attn_varlen_func", +] diff --git a/flash_attn/cute/ampere_helpers.py b/flash_attn/cute/ampere_helpers.py index 41238edc365..e3072d8ce85 100644 --- a/flash_attn/cute/ampere_helpers.py +++ b/flash_attn/cute/ampere_helpers.py @@ -6,18 +6,32 @@ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: - dtype_byte = dtype.width // 8 - bytes_per_row = k_dim * dtype_byte - smem_k_block_size = (128 if bytes_per_row % 128 == 0 else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))) // dtype_byte - swizzle_bits = 4 if smem_k_block_size == 128 else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + dtype_byte = cutlass.const_expr(dtype.width // 8) + bytes_per_row = cutlass.const_expr(k_dim * dtype_byte) + smem_k_block_size = ( + cutlass.const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) + // dtype_byte + ) + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) return cute.make_composed_layout( cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), 0, - cute.make_ordered_layout((8 if k_dim % 32 == 0 else 16, smem_k_block_size), order=(1, 0)), + cute.make_ordered_layout( + (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0) + ), ) +@cute.jit def gemm( tiled_mma: cute.TiledMma, acc: cute.Tensor, @@ -32,29 +46,44 @@ def gemm( B_in_regs: cutlass.Constexpr[bool] = False, swap_AB: cutlass.Constexpr[bool] = False, ) -> None: - if swap_AB: + if cutlass.const_expr(swap_AB): gemm( - tiled_mma, acc, tCrB, tCrA, tCsB, tCsA, smem_thr_copy_B, smem_thr_copy_A, hook_fn, - A_in_regs=B_in_regs, B_in_regs=A_in_regs, swap_AB=False + tiled_mma, + acc, + tCrB, + tCrA, + tCsB, + tCsA, + smem_thr_copy_B, + smem_thr_copy_A, + hook_fn, + A_in_regs=B_in_regs, + B_in_regs=A_in_regs, + swap_AB=False, ) else: tCrA_copy_view = smem_thr_copy_A.retile(tCrA) tCrB_copy_view = smem_thr_copy_B.retile(tCrB) - if not A_in_regs: + if cutlass.const_expr(not A_in_regs): cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0]) - if not B_in_regs: + if cutlass.const_expr(not B_in_regs): cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) - for k in range(cute.size(tCsA.shape[2])): + for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])): if k < cute.size(tCsA.shape[2]) - 1: - if not A_in_regs: - cute.copy(smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]) - if not B_in_regs: - cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]) + if cutlass.const_expr(not A_in_regs): + cute.copy( + smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1] + ) + if cutlass.const_expr(not B_in_regs): + cute.copy( + smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1] + ) cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) if cutlass.const_expr(k == 0 and hook_fn is not None): hook_fn() +@cute.jit def gemm_rs( tiled_mma: cute.TiledMma, acc: cute.Tensor, @@ -66,8 +95,8 @@ def gemm_rs( ) -> None: tCrB_copy_view = smem_thr_copy_B.retile(tCrB) cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) - for k in range(cute.size(tCrA.shape[2])): - if k < cute.size(tCrA.shape[2]) - 1: + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1): cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]) cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) if cutlass.const_expr(k == 0 and hook_fn is not None): diff --git a/flash_attn/cute/barrier.py b/flash_attn/cute/barrier.py new file mode 100644 index 00000000000..c999b180167 --- /dev/null +++ b/flash_attn/cute/barrier.py @@ -0,0 +1,71 @@ +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@dsl_user_op +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + state = llvm.inline_asm( + T.i32(), + [lock_ptr_i64], + "ld.global.acquire.gpu.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return cutlass.Int32(state) + + +@dsl_user_op +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.relaxed.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.release.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + read_val = Int32(0) + while read_val != val: + read_val = ld_acquire(flag_ptr) + + +@cute.jit +def arrive_inc( + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + red_release(flag_ptr, val) + # red_relaxed(flag_ptr, val) diff --git a/flash_attn/cute/benchmark.py b/flash_attn/cute/benchmark.py new file mode 100644 index 00000000000..9a7820e7b0c --- /dev/null +++ b/flash_attn/cute/benchmark.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, Tri Dao. +"""Useful functions for writing test code.""" + +import torch +import torch.utils.benchmark as benchmark + + +def benchmark_forward( + fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs +): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward pass") + + def amp_wrapper(*inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + + t = benchmark.Timer( + stmt="fn_amp(*inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_backward( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(*inputs, y, grad): + # Set .grad to None to avoid extra operation of gradient accumulation + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(*inputs, y=y, grad=grad)", + globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_combined( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward + Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(grad, *inputs, **kwinputs): + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(grad, *inputs, **kwinputs)", + globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_fwd_bwd( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def benchmark_all( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_combined( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def pytorch_profiler( + fn, + *inputs, + trace_filename=None, + backward=False, + amp=False, + amp_dtype=torch.float16, + cpu=False, + verbose=True, + **kwinputs, +): + """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" + if backward: + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + g = torch.randn_like(out) + for _ in range(30): # Warm up + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + # Backward should be done outside autocast + if backward: + out.backward(g, retain_graph=True) + activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ + torch.profiler.ProfilerActivity.CUDA + ] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + # profile_memory=True, + with_stack=True, + ) as prof: + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + if backward: + out.backward(g, retain_graph=True) + if verbose: + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + print(prof.key_averages().table(row_limit=50)) + if trace_filename is not None: + prof.export_chrome_trace(trace_filename) + + +def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + fn(*inputs, **kwinputs) + torch.cuda.synchronize() + mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) + if verbose: + print(f"{desc} max memory: {mem}GB") + torch.cuda.empty_cache() + return mem diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py new file mode 100644 index 00000000000..e2ff2ccc9ae --- /dev/null +++ b/flash_attn/cute/blackwell_helpers.py @@ -0,0 +1,753 @@ +# Copyright (c) 2025, Tri Dao. +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import tcgen05 +from cutlass._mlir.dialects import llvm + +import flash_attn.cute.mma_sm100_desc as sm100_desc +from flash_attn.cute.utils import parse_swizzle_from_pointer + + +@cute.jit +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + swap_AB: bool = False, +) -> None: + if const_expr(swap_AB): + return gemm_w_idx( + tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False + ) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) + + +@cute.jit +def gemm_ptx_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + **kwargs, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sA_cur = None + if const_expr(sA is not None): + sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx] + sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + acc_tmem_addr = acc.iterator.toint() + gemm_ptx_partial( + mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init, **kwargs + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> cute.TiledMma: + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + return tiled_mma + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr( + sA[None, None, 0].iterator + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr( + sB[None, None, 0].iterator + ) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: Int32, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, + # sA_offset: Int32 = 0, + # acc_offset: Int32 = 0, + tA_addr: Optional[Int32] = None, +) -> None: + # acc_tmem_addr += acc_offset + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + tCrA_layout = ( + tCrA.layout + if const_expr(not is_ts) + else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + # ) + sA_offset + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to + # explicitly pass in the tA_addr for correctness. + tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr + input_args = [ + # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + cute.size(tCrA.shape[2]) + if const_expr(mbar_ptr is None) + else cute.size(tCrA.shape[2]) // 4 * 3, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: Int32, + sB_base_addr_for_desc: Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + mask = [Int32(0)] * 4 + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(sA_base_addr_for_desc).ir_value(), + Int32(sA_stage).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(sB_base_addr_for_desc).ir_value(), + Int32(sB_stage).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index d91c15c54bb..be13e70f892 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -1,61 +1,108 @@ -from typing import Tuple +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +from typing import Tuple, Optional +from dataclasses import dataclass import cutlass import cutlass.cute as cute +from cutlass import Int32, const_expr -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.seqlen_info import SeqlenInfoQK +@dataclass(frozen=True) class BlockInfo: - - def __init__( - self, - m_block_size: cutlass.Constexpr[int], - n_block_size: cutlass.Constexpr[int], - is_causal: cutlass.Constexpr[bool], - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # Only pass in if we're doing PackGQA - *, - loc=None, - ip=None - ): - self.m_block_size: cutlass.Constexpr[int] = m_block_size - self.n_block_size: cutlass.Constexpr[int] = n_block_size - self.is_causal: cutlass.Constexpr[bool] = is_causal - self.qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = qhead_per_kvhead_packgqa - self._loc = loc + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + is_local: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + window_size_left: Optional[Int32] = None + window_size_right: Optional[Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @cute.jit def get_n_block_min_max( - self, seqlen_info: SeqlenInfo, m_block: cutlass.Int32 - ) -> Tuple[cutlass.Int32, cutlass.Int32]: - n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.n_block_size) - n_block_min = 0 - if cutlass.const_expr(self.is_causal): - m_idx_max = (m_block + 1) * self.m_block_size - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): - m_idx_max = (m_idx_max - 1) // self.qhead_per_kvhead_packgqa + 1 + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + split_idx: cutlass.Int32 = 0, + num_splits: cutlass.Int32 = 1, + ) -> Tuple[Int32, Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) + if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q - n_idx_right = n_idx - n_block_max = min(cute.ceil_div(n_idx_right, self.n_block_size), n_block_max) + n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right + n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n)) + n_block_min = 0 + if const_expr(self.is_local and self.window_size_left is not None): + m_idx_min = m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) + if cutlass.const_expr(self.is_split_kv): + num_n_blocks_per_split = ( + cutlass.Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) + n_block_min = n_block_min + split_idx * num_n_blocks_per_split + n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) return n_block_min, n_block_max + @cute.jit + def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) + m_block_min = 0 + if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right + m_block_min = max(m_block_min, m_idx_right // self.tile_m) + if const_expr(self.is_local and self.window_size_left is not None): + n_idx_max = (n_block + 1) * self.tile_n + m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_left = m_idx + self.window_size_left + m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m)) + return m_block_min, m_block_max + @cute.jit def get_n_block_min_causal_local_mask( self, - seqlen_info: SeqlenInfo, - m_block: cutlass.Int32, - n_block_min: cutlass.Int32, - ) -> cutlass.Int32: - m_idx_min = m_block * self.m_block_size - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + seqlen_info: SeqlenInfoQK, + m_block: Int32, + n_block_min: Int32, + ) -> Int32: + """If we have separate iterations with causal or local masking at the start, where do we stop""" + m_idx_min = m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q - n_idx_right = n_idx - return cutlass.max(n_block_min, n_idx_right // self.n_block_size) - - def __extract_mlir_values__(self): - # We just create a dummy value. Otherwise unpack_to_irvalue in cutlass.py will complain - return [cutlass.Int32(0).ir_value()] + n_idx_right = ( + n_idx + if const_expr(not self.is_local or self.window_size_right is None) + else n_idx + self.window_size_right + ) + return cutlass.max(n_block_min, n_idx_right // self.tile_n) - def __new_from_mlir_values__(self, values): - return BlockInfo(self.m_block_size, self.n_block_size, self.is_causal, self.qhead_per_kvhead_packgqa, loc=self._loc) + @cute.jit + def get_n_block_min_before_local_mask( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + n_block_min: Int32, + ) -> Int32: + """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations""" + if const_expr(not self.is_local or self.window_size_left is None): + return n_block_min + else: + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n)) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py new file mode 100644 index 00000000000..e814d6aa458 --- /dev/null +++ b/flash_attn/cute/block_sparse_utils.py @@ -0,0 +1,816 @@ +""" +Block-sparse runtime utilities for CUTE DSL kernels. + +This module contains runtime execution functions for block-sparse attention kernels. +These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. +""" + +from typing import Callable, Optional +from functools import partial +import math +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr + +# Import data structures from block_sparsity +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute import utils + + +@cute.jit +def load_block_list( + block_indices: cute.Tensor, + block_count, + load_q_with_first: cutlass.Constexpr, + first_block_preloaded: cutlass.Constexpr, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + use_tma_q: cutlass.Constexpr, + tma_q_bytes: cutlass.Constexpr, + intra_wg_overlap: cutlass.Constexpr, +): + """Iterate over the sparse blocks and load K, V (and Q) into the pipeline. + for the intra_wg_overlap case, we overlap the loads of K and V. And this + means we need to pipeline the last V load from the partial block case, + with the loads for the full blocks. Set first_block_preloaded when the + caller has already issued the first K load for the list. + + Note: + we iterate along the block_n indices in reverse. + + Returns: + Updated kv_producer_state after processing the block list. + + """ + if block_count > 0: + if const_expr(not intra_wg_overlap): + # Peel first iteration: the first block may need to load Q alongside K, + # Parameters are already Constexpr, so no need to wrap in const_expr() + n_block_first = block_indices[block_count - 1] + extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 + pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) + + if const_expr(load_q_with_first and use_tma_q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_first, producer_state=kv_producer_state) + kv_producer_state.advance() + + for offset in cutlass.range(1, block_count): + n_block = block_indices[block_count - 1 - offset] + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + n_block_first = block_indices[block_count - 1] + if const_expr(not first_block_preloaded): + extra_tx = ( + tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 + ) + pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) + + if const_expr(load_q_with_first and use_tma_q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + + for idx in cutlass.range(block_count - 1, unroll=1): + n_block_prev = block_indices[block_count - 1 - idx] + n_block = block_indices[block_count - 2 - idx] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + + return kv_producer_state + + +@cute.jit +def finish_overlap_v_load( + block_indices: cute.Tensor, + block_count, + load_V, + pipeline_v, + kv_producer_state, +): + """Load the final V block after overlapped K/V loads.""" + if block_count > 0: + n_block_last = block_indices[0] + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_last, producer_state=kv_producer_state) + kv_producer_state.advance() + + return kv_producer_state + + +@cute.jit +def produce_block_sparse_loads( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + use_tma_q: cutlass.Constexpr, + tma_q_bytes: cutlass.Constexpr, + intra_wg_overlap: cutlass.Constexpr, +): + """Iterate over the mask and full block lists for a single tile. + + The masked (partial) list may leave the last V load pending when intra-warp-group + overlap is enabled. The first full block must consume that pending V while + issuing its own K load on the next pipeline stage. + + In the intra-wg-overlap path, the last masked block leaves its V copy in flight + while we advance the producer state to start the next full K. Either the full list + overlaps that pending V load, or, if no full blocks exist, we explicitly drain it. + + """ + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + mask_empty = curr_mask_block_cnt == 0 + full_empty = curr_full_block_cnt == 0 + + if mask_empty: + # No masked blocks: the full list owns the initial Q+K load. + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=True, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0: + kv_producer_state = finish_overlap_v_load( + curr_full_block_idx, + curr_full_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + # Masked blocks present: load Q together with the first masked K so consumers can + # start immediately. When overlap is disabled this fully drains the list. + kv_producer_state = load_block_list( + curr_mask_block_idx, + curr_mask_block_cnt, + load_q_with_first=True, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + if full_empty: + if const_expr(intra_wg_overlap): + kv_producer_state = finish_overlap_v_load( + curr_mask_block_idx, + curr_mask_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + if const_expr(intra_wg_overlap): + # Bridge the masked list to the full list by overlapping the pending masked V + # with the first full K load. + n_block_mask_last = curr_mask_block_idx[0] + n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full_first, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) + + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + first_block_preloaded=True, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + kv_producer_state = finish_overlap_v_load( + curr_full_block_idx, + curr_full_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + # Non-overlap path with both lists: run the full list normally (skipping the Q + # reload because the masked list already issued it). + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + return kv_producer_state + + +@cute.jit +def consume_block_sparse_loads( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + mask_mod, + fastdiv_mods, + intra_wg_overlap: cutlass.Constexpr, + warp_scheduler_barrier_sync: Callable, + warp_scheduler_barrier_arrive: Callable, +): + """Consume the mask and full block lists for a single tile on the consumer side. + + Mirrors `produce_block_sparse_loads` so that the consumer pipeline + """ + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + + processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0 + + if const_expr(not intra_wg_overlap): + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial( + mask_fn, + mask_mod=mask_mod, + mask_seqlen=True, + fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, + ), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + if curr_full_block_cnt == 0: + warp_scheduler_barrier_arrive() + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + is_first_n_block=False, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + warp_scheduler_barrier_arrive() + else: + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + kv_consumer_state = process_first_half_block( + n_block=mask_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial( + mask_fn, + mask_mod=mask_mod, + mask_seqlen=True, + fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, + ), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + kv_consumer_state = process_first_half_block( + n_block=full_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + ) + O_should_accumulate = True + + if curr_mask_block_cnt + curr_full_block_cnt > 0: + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + + return kv_consumer_state, O_should_accumulate, processed_any + + +@cute.jit +def load_block_list_sm100( + block_indices: cute.Tensor, + block_count, + load_q_with_first: cutlass.Constexpr, + m_block, + q_stage: cutlass.Constexpr, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, +): + """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).""" + if block_count > 0: + # First iteration: load Q alongside K if requested + n_block_first = block_indices[block_count - 1] + + if const_expr(load_q_with_first): + # SM100 loads Q0 and optionally Q1 + load_Q(block=q_stage * m_block + 0, stage=0) + if const_expr(q_stage == 2): + load_Q(block=q_stage * m_block + 1, stage=1) + + # SM100 doesn't use producer_acquire for pipeline_kv in load path + # The pipeline barriers are handled inside load_KV + load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + # Remaining blocks + for offset in cutlass.range(1, block_count): + n_block = block_indices[block_count - 1 - offset] + load_K(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + return kv_producer_state + + +# SM100-specific tile processor using SM100 helpers +@cute.jit +def produce_block_sparse_loads_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, + q_stage: cutlass.Constexpr, + q_producer_phase: Int32, +): + """SM100 entry point for sparse block iteration. + + SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use + simplified block processing that just calls producer_acquire without extras. + """ + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + mask_empty = curr_mask_block_cnt == 0 + full_empty = curr_full_block_cnt == 0 + + q_phase_flipped = False + + if mask_empty: + # No masked blocks: process full list with Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = not full_empty + else: + # Process masked blocks with Q loading + kv_producer_state = load_block_list_sm100( + curr_mask_block_idx, + curr_mask_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = True + + if not full_empty: + # Process full blocks without Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + + if q_phase_flipped: + q_producer_phase ^= 1 + + return kv_producer_state, q_producer_phase + + +@cute.jit +def get_total_block_count( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, +): + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + if const_expr(full_block_cnt is not None): + return ( + mask_block_cnt[batch_idx, head_idx, m_block] + + full_block_cnt[batch_idx, head_idx, m_block] + ) + else: + return mask_block_cnt[batch_idx, head_idx, m_block] + + +@cute.jit +def handle_block_sparse_empty_tile_correction_sm100( + tidx: Int32, + q_stage: cutlass.Constexpr, + m_block_size: cutlass.Constexpr, + qhead_per_kvhead, + pack_gqa: cutlass.Constexpr, + is_split_kv: cutlass.Constexpr, + learnable_sink, + mLSE, + seqlen, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, + split_idx: Int32, + sScale: cute.Tensor, + stats: list, + correction_epilogue: Callable, + thr_mma_pv: cute.core.ThrMma, + tOtOs: tuple[cute.Tensor], + sO: cute.Tensor, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + mbar_corr_epi_full_offset: Int32, + mbar_corr_epi_empty_offset: Int32, + softmax_corr_consumer_phase: Int32, + o_corr_consumer_phase: Int32, + corr_epi_producer_phase: Int32, + softmax_scale_log2: Float32, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, +): + """Handle the block-sparse case where a tile is fully masked: + * zero staged results + * seed stats + * satisfy the usual barrier protocol so downstream warps continue to make progress. + """ + LOG2_E = Float32(math.log2(math.e)) + + for stage in cutlass.range_constexpr(q_stage): + row_sum_value = Float32(1.0) + row_max_value = ( + -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None + ) + if const_expr(learnable_sink is not None): + sink_val = -Float32.inf + if const_expr(not pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + elif tidx < m_block_size: + q_head_idx = ( + (q_stage * m_block + stage) * m_block_size + tidx + ) % qhead_per_kvhead + head_idx * qhead_per_kvhead + sink_val = Float32(learnable_sink[q_head_idx]) + if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): + if row_max_value == -Float32.inf: + row_max_value = sink_val * (LOG2_E / softmax_scale_log2) + row_sum_value = Float32(1.0) + else: + row_sum_value = row_sum_value + utils.exp2f( + sink_val * LOG2_E - row_max_value * softmax_scale_log2 + ) + if tidx < m_block_size: + scale_row_idx = tidx + stage * m_block_size + sScale[scale_row_idx] = row_sum_value + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[scale_row_idx + m_block_size * 2] = row_max_value + acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value + stats[stage] = (row_sum_value, row_max_value, acc_flag) + + cute.arch.mbarrier_wait( + mbar_ptr + mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage) + + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_wait( + mbar_ptr + mbar_corr_epi_empty_offset + stage, + corr_epi_producer_phase, + ) + correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + stage, + m_block, + seqlen.seqlen_q, + Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs + sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, + ) + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage) + + softmax_corr_consumer_phase ^= 1 + o_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + + return ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) + + +@cute.jit +def softmax_block_sparse_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + softmax_step: Callable, + mask_fn: Callable, + mask_fn_none: Callable, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + q_stage: cutlass.Constexpr, + stage_idx: Int32, +): + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt + + if total_block_cnt == 0: + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx) + else: + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), # last block could oob + ) + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=True, + mask_fn=partial(mask_fn_none, mask_seqlen=True), + ) + else: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=False, + mask_fn=partial(mask_fn_none, mask_seqlen=False), + ) + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + mask_fn=partial(mask_fn_none, mask_seqlen=False), + ) + + return ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + total_block_cnt == 0, + ) diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py new file mode 100644 index 00000000000..cefb48e7e24 --- /dev/null +++ b/flash_attn/cute/block_sparsity.py @@ -0,0 +1,650 @@ +""" +Computes block-sparse attention masks for Flex Attention. + +This utility generates block sparsity patterns based on common attention masking +strategies (e.g., causal, sliding window). The resulting tensors define which +blocks are fully computed, which are partially computed (requiring a mask), and +which are skipped entirely. This is a temporary solution intended to be replaced +by a more robust preprocessing kernel in the future. +""" + +from typing import Tuple, Optional, Callable, List, NamedTuple +import torch +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + + +# placeholder +Config = type("Config", (), {}) + + +class BlockSparseTensors(NamedTuple): + mask_block_cnt: cute.Tensor + mask_block_idx: cute.Tensor + full_block_cnt: Optional[cute.Tensor] + full_block_idx: Optional[cute.Tensor] + + def __new_from_mlir_values__(self, values): + if len(values) == 2: + values = (*values, None, None) + return BlockSparseTensors(*values) + + +class BlockSparseTensorsTorch(NamedTuple): + mask_block_cnt: torch.Tensor + mask_block_idx: torch.Tensor + full_block_cnt: Optional[torch.Tensor] = None + full_block_idx: Optional[torch.Tensor] = None + + +def _expand_sparsity_tensor( + tensor: torch.Tensor, + expected_shape: Tuple[int, ...], + tensor_name: str, +) -> torch.Tensor: + """Check if we need to expand the tensor to expected shape, and do so if possible.""" + needs_expand = tensor.shape != expected_shape + if not needs_expand: + return tensor + can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape)) + if not can_expand: + raise ValueError( + f"{tensor_name} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}." + ) + return tensor.expand(*expected_shape).contiguous() + + +def _check_and_expand_block( + name: str, + cnt: Optional[torch.Tensor], + idx: Optional[torch.Tensor], + expected_count_shape: Tuple[int, int, int], + expected_index_shape: Tuple[int, int, int, int], +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if (cnt is None) != (idx is None): + raise ValueError( + f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" + ) + if cnt is None or idx is None: + return None, None + if cnt.dtype != torch.int32 or idx.dtype != torch.int32: + raise ValueError(f"{name}_block tensors must have dtype torch.int32") + if cnt.device != idx.device: + raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") + if not cnt.is_cuda or not idx.is_cuda: + raise ValueError(f"{name}_block tensors must live on CUDA") + expanded_cnt = _expand_sparsity_tensor(cnt, expected_count_shape, f"{name}_block_cnt") + expanded_idx = _expand_sparsity_tensor(idx, expected_index_shape, f"{name}_block_idx") + return expanded_cnt, expanded_idx + + +def normalize_block_sparse_tensors( + tensors: BlockSparseTensorsTorch, + *, + expected_count_shape: Tuple[int, int, int], + expected_index_shape: Tuple[int, int, int, int], +) -> BlockSparseTensorsTorch: + if tensors.mask_block_cnt is None or tensors.mask_block_idx is None: + raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") + + mask_cnt, mask_idx = _check_and_expand_block( + "mask", + tensors.mask_block_cnt, + tensors.mask_block_idx, + expected_count_shape, + expected_index_shape, + ) + if mask_cnt is None or mask_idx is None: + raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") + + full_cnt, full_idx = _check_and_expand_block( + "full", + tensors.full_block_cnt, + tensors.full_block_idx, + expected_count_shape, + expected_index_shape, + ) + if full_cnt is not None and mask_cnt.device != full_cnt.device: + raise ValueError("All block sparse tensors must be on the same device") + + return BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + ) + + +def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: + return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt)) + + +def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> Optional[BlockSparseTensors]: + if not is_block_sparsity_enabled(tensors): + return None + + mask_block_cnt_tensor = from_dlpack( + tensors.mask_block_cnt.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=2) + mask_block_idx_tensor = from_dlpack( + tensors.mask_block_idx.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=3) + full_block_cnt_tensor = ( + from_dlpack(tensors.full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + if tensors.full_block_cnt is not None + else None + ) + full_block_idx_tensor = ( + from_dlpack(tensors.full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + if tensors.full_block_idx is not None + else None + ) + + return BlockSparseTensors( + mask_block_cnt_tensor, + mask_block_idx_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + ) + + +def compute_block_sparsity( + config: Config, + mask_mod_flex: Optional[Callable], + device: str, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + aux_tensors: Optional[List[torch.Tensor]] = None, +) -> Tuple[ + Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor] +]: + """ + Computes block sparsity tensors from a given masking function. + + This function serves as the main entry point for generating block-sparse masks. + It dispatches to specialized handlers for variable-length and fixed-length + sequences. + + Args: + config: A configuration object containing model and tiling parameters. + mask_mod_flex: The mask function for generic flex attention patterns. + device: The device to create tensors on (e.g., 'cuda'). + cu_seqlens_q: Cumulative sequence lengths for Q (for varlen). + cu_seqlens_k: Cumulative sequence lengths for K (for varlen). + aux_tensors: A list of auxiliary tensors, e.g., for document masking. + + Returns: + A tuple of four tensors: + - `full_block_cnt`: (batch, nheads, n_blocks_q) - Count of full n blocks per m block. + - `full_block_idx`: (batch, nheads, n_blocks_q, max_n_blocks) - Indices of full n blocks. + - `mask_block_cnt`: (batch, nheads, n_blocks_q) - Count of partial n blocks per m block. + - `mask_block_idx`: (batch, nheads, n_blocks_q, max_n_blocks) - Indices of partial n blocks. + Returns (None, None, None, None) if masking is disabled. + """ + if not config.use_mask_mod or mask_mod_flex is None: + return None, None, None, None + + if cu_seqlens_q is not None: + # Handle variable-length sequences + return _compute_varlen_sparsity(config, mask_mod_flex, device, cu_seqlens_q, cu_seqlens_k) + else: + # Handle fixed-length sequences + return _compute_sparsity(config, device, aux_tensors) + + +## --------------------------------------------------------------------------- +## Fixed-Length Sequence Kernels +## --------------------------------------------------------------------------- + + +def _compute_sparsity( + config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes block sparsity for fixed-length sequences.""" + n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + + # Pre-allocate output tensors + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32 + ) + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32 + ) + + # --- Identity Mask --- + # All blocks are fully computed. + if config.mask_mod_name == "identity": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + full_block_cnt[:, :, q_block_idx] = n_blocks_k + full_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks + + # --- Identity Partial Mask --- + # All blocks are partially computed (masked). + elif config.mask_mod_name == "identity_partial": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + mask_block_cnt[:, :, q_block_idx] = n_blocks_k + mask_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks + + # --- Block Causal Mask --- + elif config.mask_mod_name == "block_causal": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + causal_indices = k_blocks[k_blocks <= q_block_idx] + num_causal_indices = len(causal_indices) + if num_causal_indices > 0: + full_block_cnt[:, :, q_block_idx] = num_causal_indices + full_block_idx[:, :, q_block_idx, :num_causal_indices] = causal_indices + + # --- Causal and Sliding Window Masks --- + elif config.mask_mod_name in ["causal", "sliding_window"]: + q_block_indices = torch.arange(n_blocks_q, device=device) + k_block_indices = torch.arange(n_blocks_k, device=device) + + q_starts = q_block_indices * config.tile_m + q_ends = torch.minimum( + (q_block_indices + 1) * config.tile_m, torch.tensor(config.seqlen_q, device=device) + ) + k_starts = k_block_indices * config.tile_n + k_ends = torch.minimum( + (k_block_indices + 1) * config.tile_n, torch.tensor(config.seqlen_k, device=device) + ) + + # Expand dims for broadcasting: (n_blocks_q, 1) and (1, n_blocks_k) + q_starts, q_ends = q_starts.unsqueeze(1), q_ends.unsqueeze(1) + k_starts, k_ends = k_starts.unsqueeze(0), k_ends.unsqueeze(0) + + offset = config.seqlen_k - config.seqlen_q + + if config.mask_mod_name == "causal": + is_full = (k_ends - 1) <= (q_starts + offset) + # min(k_pos) <= max(q_pos) AND not is_full. + is_partial = (k_starts <= (q_ends - 1 + offset)) & ~is_full + + else: # sliding_window + window_size = getattr(config, "window_size", 1024) + is_full = (k_ends - 1 <= q_starts + offset) & ( + k_starts >= q_ends - 1 + offset - (window_size - 1) + ) + # A block is EMPTY if no (q, k) pairs satisfy the constraint. + is_empty = (k_starts > q_ends - 1 + offset) | ( + k_ends - 1 < q_starts + offset - (window_size - 1) + ) + # A block is PARTIAL if it's not empty and not full. + is_partial = ~is_empty & ~is_full + + # Populate indices based on the computed block classifications + for q_block_idx in range(n_blocks_q): + full_indices = k_block_indices[is_full[q_block_idx]] + if len(full_indices) > 0: + full_block_cnt[:, :, q_block_idx] = len(full_indices) + full_block_idx[:, :, q_block_idx, : len(full_indices)] = full_indices + + partial_indices = k_block_indices[is_partial[q_block_idx]] + if len(partial_indices) > 0: + mask_block_cnt[:, :, q_block_idx] = len(partial_indices) + mask_block_idx[:, :, q_block_idx, : len(partial_indices)] = partial_indices + + elif config.mask_mod_name == "document": + raise NotImplementedError("Block sparsity for document masking not yet implemented") + + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + + +## --------------------------------------------------------------------------- +## Variable-Length Sequence Kernels +## --------------------------------------------------------------------------- + + +def _compute_varlen_sparsity( + config: Config, + mask_mod_flex: Callable, + device: str, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes block sparsity for variable-length sequences.""" + assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen attention" + assert cu_seqlens_q.shape[0] == config.batch_size + 1 + assert cu_seqlens_k.shape[0] == config.batch_size + 1 + + # In varlen, each sequence can have a different number of Q blocks. + # We pad up to the maximum number of Q blocks in the batch. + max_m_blocks = 0 + for seq_idx in range(config.batch_size): + seq_len_q = (cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]).item() + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + max_m_blocks = max(max_m_blocks, n_blocks_q) + + # The number of K blocks is determined by the total length of all sequences. + total_k_len = cu_seqlens_k[-1].item() + max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n + + # Pre-allocate padded output tensors + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), + device=device, + dtype=torch.int32, + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), + device=device, + dtype=torch.int32, + ) + + # Process each sequence in the batch individually + for seq_idx in range(config.batch_size): + seq_start_q = cu_seqlens_q[seq_idx].item() + seq_end_q = cu_seqlens_q[seq_idx + 1].item() + seq_len_q = seq_end_q - seq_start_q + + seq_start_k = cu_seqlens_k[seq_idx].item() + seq_end_k = cu_seqlens_k[seq_idx + 1].item() + seq_len_k = seq_end_k - seq_start_k + + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + + # Global block indices are relative to the start of the entire batch tensor + first_m_block_global = seq_start_q // config.tile_m + first_n_block_global = seq_start_k // config.tile_n + + common_args = { + "full_block_cnt": full_block_cnt, + "full_block_idx": full_block_idx, + "mask_block_cnt": mask_block_cnt, + "mask_block_idx": mask_block_idx, + "seq_idx": seq_idx, + "n_blocks_q": n_blocks_q, + "n_blocks_k": n_blocks_k, + "seq_start_q": seq_start_q, + "seq_end_q": seq_end_q, + "seq_start_k": seq_start_k, + "seq_end_k": seq_end_k, + "first_n_block_global": first_n_block_global, + "tile_m": config.tile_m, + "tile_n": config.tile_n, + "device": device, + } + + if config.mask_mod_name == "causal": + _compute_causal_varlen_blocks(**common_args) + elif config.mask_mod_name == "sliding_window": + window_size = getattr(config, "window_size", 1024) + _compute_sliding_window_varlen_blocks(**common_args, window_size=window_size) + elif config.mask_mod_name == "identity": + _compute_identity_varlen_blocks( + full_block_cnt, + full_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + first_n_block_global, + device, + ) + else: + # Generic case relies on sampling the user-provided mask function + _compute_generic_varlen_blocks( + **common_args, + mask_mod_flex=mask_mod_flex, + seq_len_q=seq_len_q, + seq_len_k=seq_len_k, + num_heads=config.nheads, + nheads_kv=config.nheads_kv, + ) + + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + + +def _classify_varlen_block( + m_local: int, + n_local: int, + seq_start_q: int, + seq_end_q: int, + seq_start_k: int, + seq_end_k: int, + tile_m: int, + tile_n: int, + is_full_fn: Callable, + is_partial_fn: Callable, +) -> Tuple[bool, bool]: + """Helper to classify a varlen block as full, partial, or empty.""" + m_start_global = seq_start_q + m_local * tile_m + m_end_global = min(seq_start_q + (m_local + 1) * tile_m, seq_end_q) + n_start_global = seq_start_k + n_local * tile_n + n_end_global = min(seq_start_k + (n_local + 1) * tile_n, seq_end_k) + + # Use sequence-local coordinates for the logical check + m_start_local = m_start_global - seq_start_q + m_end_local = m_end_global - seq_start_q + n_start_local = n_start_global - seq_start_k + n_end_local = n_end_global - seq_start_k + + is_full = is_full_fn(m_start_local, m_end_local, n_start_local, n_end_local) + is_partial = ( + is_partial_fn(m_start_local, m_end_local, n_start_local, n_end_local) and not is_full + ) + + # Any block that touches the sequence boundary is partial because it requires masking. + at_boundary = (m_end_global > seq_end_q) or (n_end_global > seq_end_k) + + return is_full and not at_boundary, is_partial or (is_full and at_boundary) + + +def _compute_causal_varlen_blocks( + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + first_n_block_global, + tile_m, + tile_n, + device, + **kwargs, +): + """Computes causal block sparsity for a single varlen sequence.""" + is_full_fn = lambda m_start, m_end, n_start, n_end: (m_start >= n_end - 1) + is_partial_fn = lambda m_start, m_end, n_start, n_end: (m_end - 1 >= n_start) + + for m_local in range(n_blocks_q): + full_blocks, partial_blocks = [], [] + for n_local in range(n_blocks_k): + is_full, is_partial = _classify_varlen_block( + m_local, + n_local, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + tile_m, + tile_n, + is_full_fn, + is_partial_fn, + ) + n_block_global = first_n_block_global + n_local + if is_full: + full_blocks.append(n_block_global) + elif is_partial: + partial_blocks.append(n_block_global) + + if full_blocks: + full_block_cnt[seq_idx, :, m_local] = len(full_blocks) + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) + if partial_blocks: + mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) + + +def _compute_sliding_window_varlen_blocks( + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + first_n_block_global, + tile_m, + tile_n, + window_size, + device, + **kwargs, +): + """Computes sliding window block sparsity for a single varlen sequence.""" + is_full_fn = lambda m_start, m_end, n_start, n_end: (n_end - 1 <= m_start) and ( + n_start >= m_start - window_size + 1 + ) + is_partial_fn = lambda m_start, m_end, n_start, n_end: not ( + (n_start > m_end - 1) or (n_end - 1 < m_start - window_size + 1) + ) + + for m_local in range(n_blocks_q): + full_blocks, partial_blocks = [], [] + for n_local in range(n_blocks_k): + is_full, is_partial = _classify_varlen_block( + m_local, + n_local, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + tile_m, + tile_n, + is_full_fn, + is_partial_fn, + ) + n_block_global = first_n_block_global + n_local + if is_full: + full_blocks.append(n_block_global) + elif is_partial: + partial_blocks.append(n_block_global) + + if full_blocks: + full_block_cnt[seq_idx, :, m_local] = len(full_blocks) + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) + if partial_blocks: + mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) + + +def _compute_identity_varlen_blocks( + full_block_cnt, + full_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + first_n_block_global, + device, + **kwargs, +): + """Computes identity (all-attend) block sparsity for a single varlen sequence.""" + n_blocks_global = torch.arange( + first_n_block_global, first_n_block_global + n_blocks_k, device=device, dtype=torch.int32 + ) + for m_local in range(n_blocks_q): + full_block_cnt[seq_idx, :, m_local] = n_blocks_k + full_block_idx[seq_idx, :, m_local, :n_blocks_k] = n_blocks_global + + +def _compute_generic_varlen_blocks( + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + mask_mod_flex, + seq_idx, + num_heads, + n_blocks_q, + n_blocks_k, + seq_len_q, + seq_len_k, + first_n_block_global, + tile_m, + tile_n, + nheads_kv, + device, + **kwargs, +): + """Generic sampling-based block classification for a varlen sequence.""" + qhead_per_kvhead = num_heads // nheads_kv + + for h_q in range(num_heads): + h_kv = h_q // qhead_per_kvhead + for m_local in range(n_blocks_q): + m_start_local = m_local * tile_m + m_end_local = min((m_local + 1) * tile_m, seq_len_q) + + full_blocks, partial_blocks = [], [] + for n_local in range(n_blocks_k): + n_start_local = n_local * tile_n + n_end_local = min((n_local + 1) * tile_n, seq_len_k) + + # Sample points within the block (corners and center) to classify it. + # Coordinates are sequence-local, as required by mask_mod_flex. + sample_positions = [ + (m_start_local, n_start_local), + (m_start_local, n_end_local - 1), + (m_end_local - 1, n_start_local), + (m_end_local - 1, n_end_local - 1), + ((m_start_local + m_end_local) // 2, (n_start_local + n_end_local) // 2), + ] + + unmasked_count = sum( + 1 + for q_pos, k_pos in sample_positions + if mask_mod_flex(seq_idx, h_q, q_pos, k_pos, seq_len_q, seq_len_k) + ) + + n_block_global = first_n_block_global + n_local + if unmasked_count == len(sample_positions): # All samples unmasked -> full + full_blocks.append(n_block_global) + elif unmasked_count > 0: # Some unmasked -> partial + partial_blocks.append(n_block_global) + + if full_blocks: + full_block_cnt[seq_idx, h_q, m_local] = len(full_blocks) + full_block_idx[seq_idx, h_q, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) + if partial_blocks: + mask_block_cnt[seq_idx, h_q, m_local] = len(partial_blocks) + mask_block_idx[seq_idx, h_q, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py new file mode 100644 index 00000000000..acaeac794c5 --- /dev/null +++ b/flash_attn/cute/compute_block_sparsity.py @@ -0,0 +1,400 @@ +from functools import partial +from typing import Callable, Optional, Tuple + +import cutlass +from cutlass import Boolean, Int32, Int8, const_expr +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import torch + +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar + + +class BlockSparsityKernel: + """Block sparsity kernel for FlexAttention. + + This kernel computes `mask_mod` for every token of each block + to determine if an n block is full, masked, or neither. + + Writes block counts and indices to a BlockSparseTensors object. + + When use_fast_sampling=True, uses 5-point sampling (4 corners + center) + which is much faster but only suitable for masks where this is sufficient. + """ + + def __init__( + self, + mask_mod: Callable, + tile_mn: Tuple[int, int], + compute_full_blocks: bool = True, + use_aux_tensors: bool = False, + use_fast_sampling: bool = False, + ): + self.mask_mod = mask_mod + self.tile_mn = tile_mn + self.compute_full_blocks = compute_full_blocks + self.use_aux_tensors = use_aux_tensors + self.use_fast_sampling = use_fast_sampling + + @cute.jit + def __call__( + self, + blocksparse_tensors: BlockSparseTensors, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors + self.seqlen_q = seqlen_q + self.seqlen_k = seqlen_k + + if const_expr(self.compute_full_blocks): + assert self.full_cnt is not None and self.full_idx is not None, ( + "full block tensors must be provided when computing full blocks" + ) + + batch_size, num_heads, num_m_blocks, num_n_blocks = list(self.mask_idx.shape) + grid = [num_m_blocks, num_heads, batch_size] + + # Fast sampling uses only 5 threads (4 corners + center), full sampling uses 1 thread per row + if const_expr(self.use_fast_sampling): + num_threads = 5 + self.num_warps = 1 + else: + num_threads = self.tile_mn[0] + self.num_warps = (num_threads + 32 - 1) // 32 + + self.kernel( + self.mask_cnt, + self.mask_idx, + self.full_cnt, + self.full_idx, + num_n_blocks, + seqlen_q, + seqlen_k, + aux_tensors, + ).launch(grid=grid, block=[num_threads, 1, 1]) + + @cute.kernel + def kernel( + self, + mask_cnt: cute.Tensor, + mask_idx: cute.Tensor, + full_cnt: cute.Tensor, + full_idx: cute.Tensor, + num_n_blocks: Int32, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + # Store seqlens as instance variables for use in the kernel + self.seqlen_q = seqlen_q + self.seqlen_k = seqlen_k + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + m_block, head_idx, batch_idx = cute.arch.block_idx() + + ssa = partial(scalar_to_ssa, dtype=Int32) + + @cute.struct + class SharedStorage: + reduction_buffer_smem: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024 + ] + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage, 16) + + reduction_buffer = storage.reduction_buffer_smem.get_tensor( + cute.make_layout((self.num_warps, 2)) + ) + + num_mask_blocks = Int32(0) + num_full_blocks = Int32(0) + + for n_block in cutlass.range(num_n_blocks, unroll_full=True): + m_base = m_block * self.tile_mn[0] + n_base = n_block * self.tile_mn[1] + + if const_expr(self.use_fast_sampling): + # Fast path: 5-point sampling (4 corners + center) + # Out-of-bounds indices are treated as masked (False) + thread_result = Boolean(False) + thread_is_valid = Boolean(False) + q_idx = Int32(0) + kv_idx = Int32(0) + + if tidx == 0: + # Top-left corner (0, 0) + q_idx = m_base + kv_idx = n_base + elif tidx == 1: + # Top-right corner + q_idx = m_base + kv_idx = n_base + self.tile_mn[1] - 1 + elif tidx == 2: + # Bottom-left corner + q_idx = m_base + self.tile_mn[0] - 1 + kv_idx = n_base + elif tidx == 3: + # Bottom-right corner + q_idx = m_base + self.tile_mn[0] - 1 + kv_idx = n_base + self.tile_mn[1] - 1 + elif tidx == 4: + # Center point + q_idx = m_base + self.tile_mn[0] // 2 + kv_idx = n_base + self.tile_mn[1] // 2 + + # Check bounds and determine if this thread has a valid index pair + if q_idx < self.seqlen_q and kv_idx < self.seqlen_k: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + kv_idx_ssa = ssa(kv_idx) + thread_result = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), ssa(head_idx), q_idx_ssa, kv_idx_ssa, aux_tensors + ) + ) + else: + thread_is_valid = Boolean(False) + + # Use vote_any_sync to see if any valid thread found unmasked or masked + # Only count results from threads that checked valid indices + has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid) + has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid) + + else: + # Full path: check all elements in the block + # Track if this thread's row has any masked or unmasked elements + thread_has_unmasked = Boolean(False) + thread_has_masked = Boolean(False) + thread_is_valid = Boolean(False) + + # Each thread handles 1 row + q_idx = m_base + tidx + kv_idx = Int32(0) + if tidx < self.tile_mn[0] and q_idx < self.seqlen_q: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + + # Loop over all columns in this row + for c in cutlass.range(self.tile_mn[1], unroll_full=True): + kv_idx = n_base + c + kv_idx_ssa = ssa(kv_idx) + + # Only check elements within valid sequence bounds + if kv_idx < self.seqlen_k: + # Direct scalar call + mask_val = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), + ssa(head_idx), + q_idx_ssa, + kv_idx_ssa, + aux_tensors, + ) + ) + + # Update tracking flags + if mask_val: + thread_has_unmasked = Boolean(True) + else: + thread_has_masked = Boolean(True) + + # Block-level reduction to combine results across all threads + # Only count votes from threads that checked valid indices + warp_has_unmasked_mask = cute.arch.vote_any_sync( + thread_has_unmasked & thread_is_valid + ) + warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid) + + # lane 0 writes the ballot mask to shared memory + lane_id = tidx % 32 + if lane_id == 0: + # Store as Int8 + reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0) + reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0) + + cute.arch.sync_threads() + + # Thread 0 ORs all warp results together + has_unmasked = Boolean(False) + has_masked = Boolean(False) + if tidx == 0: + for w in cutlass.range(self.num_warps): + if reduction_buffer[w, 0]: + has_unmasked = Boolean(True) + if reduction_buffer[w, 1]: + has_masked = Boolean(True) + + # Only thread 0 updates the output arrays (common to both paths) + if tidx == 0: + # Block classification based on what we found: + # - If has_masked and has_unmasked: partial block (needs masking) + # - If only has_unmasked: full block (no masking needed) + # - If only has_masked: skip this block entirely + is_partial = Boolean(has_masked and has_unmasked) + is_full = Boolean(has_unmasked and (not has_masked)) + + if is_partial: + mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block + num_mask_blocks += 1 + elif is_full and const_expr(self.compute_full_blocks): + full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block + num_full_blocks += 1 + + # Only thread 0 writes back the counts + if tidx == 0: + mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[batch_idx, head_idx, m_block] = num_full_blocks + + +def compute_block_sparsity( + tile_m, + tile_n, + batch_size, + num_heads, + seqlen_q, + seqlen_k, + mask_mod: Callable, + aux_tensors: Optional[list], # list[cute.Tensor] + device, + compute_full_blocks: bool = True, + use_fast_sampling: bool = False, +) -> Tuple[BlockSparseTensors, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Computes block sparsity for a given `mask_mod`. + + Args: + tile_m: The tile size for the m dimension. + tile_n: The tile size for the n dimension. + batch_size: The batch size. + num_heads: The number of heads. + seqlen_q: The sequence length for the query. + seqlen_k: The sequence length for the key. + mask_mod: The `mask_mod` callable to use. + aux_tensors: A list of auxiliary tensors. + device: The device to use. + compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. + use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. + + Returns: + A tuple of `BlockSparseTensors` and the underlying torch tensors. + """ + num_m_blocks = (seqlen_q + tile_m - 1) // tile_m + num_n_blocks = (seqlen_k + tile_n - 1) // tile_n + + mask_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + full_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + + # Convert to cute tensors + mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + + blocksparse_tensors = BlockSparseTensors( + mask_block_cnt=mask_cnt_cute, + mask_block_idx=mask_idx_cute, + full_block_cnt=full_cnt_cute, + full_block_idx=full_idx_cute, + ) + + mask_mod_hash = hash_callable(mask_mod) + + compile_key = ( + tile_m, + tile_n, + mask_mod_hash, + compute_full_blocks, + aux_tensors is not None, + use_fast_sampling, + ) + if compile_key not in compute_block_sparsity.compile_cache: + kernel = BlockSparsityKernel( + mask_mod, + tile_mn=(tile_m, tile_n), + compute_full_blocks=True, + use_aux_tensors=aux_tensors is not None, + use_fast_sampling=use_fast_sampling, + ) + + compute_block_sparsity.compile_cache[compile_key] = cute.compile( + kernel, + blocksparse_tensors, + seqlen_q, + seqlen_k, + aux_tensors, + ) + + compute_block_sparsity.compile_cache[compile_key]( + blocksparse_tensors, + seqlen_q, + seqlen_k, + aux_tensors, + ) + + # Return both the BlockSparseTensors (cute) and the underlying torch tensors + return blocksparse_tensors, (full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx) + + +compute_block_sparsity.compile_cache = {} + + +def run(): + """Test the BlockSparsityKernel with a simple causal mask.""" + + print("Testing BlockSparsityKernel...") + + # Configuration + batch_size = 2 + num_heads = 2 + seqlen_q = 16384 + seqlen_k = 16384 + tile_m, tile_n = 128, 128 # Use very small tiles for initial testing + + # Define a simple causal mask function + @cute.jit + def causal_mask(batch_idx, head_idx, q_idx, kv_idx, aux_tensors): + """Simple causal mask: only attend to positions <= current position.""" + return q_idx >= kv_idx + + try: + compute_block_sparsity( + tile_m, + tile_n, + batch_size, + num_heads, + seqlen_q, + seqlen_k, + causal_mask, + None, + device="cuda", + ) + print("Kernel execution completed!") + except Exception as e: + print(f"Kernel execution failed: {e}") + + +if __name__ == "__main__": + run() diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py new file mode 100644 index 00000000000..cfdcbdb80a0 --- /dev/null +++ b/flash_attn/cute/copy_utils.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import math +from typing import Optional, Type, Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass.pipeline + + +@dsl_user_op +def cvt_copy( + atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def make_tmem_copy( + tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None +) -> cute.CopyAtom: + num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) + assert num_dp == 32 + assert num_bits == 32 + tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) + layout_tv = cute.make_layout( + ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ) + return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@dsl_user_op +def atomic_add_fp32x4( + a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x12F0000000000000) + llvm.inline_asm( + None, + [ + gmem_ptr_i64, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + "{\n\t" + # ".reg .b128 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + ".reg .v4 .f32 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + "mov.f32 abcd.x, $1;\n\t" + "mov.f32 abcd.y, $2;\n\t" + "mov.f32 abcd.z, $3;\n\t" + "mov.f32 abcd.w, $4;\n\t" + "red.global.add.v4.f32 [$0], abcd;\n\t" + # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" + "}\n", + # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + "l,f,f,f,f", + # "l,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote_fp32x4( + a: Float32, + b: Float32, + c: Float32, + d: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .v4 .f32 abcd;\n\t" + "mov.f32 abcd.x, $2;\n\t" + "mov.f32 abcd.y, $3;\n\t" + "mov.f32 abcd.z, $4;\n\t" + "mov.f32 abcd.w, $5;\n\t" + "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t" + "}\n", + "r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_g2s( + gmem_ptr: cute.Pointer, + smem_ptr: cute.Pointer, + tma_bar_ptr: cute.Pointer, + size: int | Int32, + *, + loc=None, + ip=None, +): + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", + "l,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + # src_is_smem = const_expr( + # isinstance(src_tensor.iterator, cute.Pointer) + # and src_tensor.memspace == cute.AddressSpace.smem + # ) + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, **new_kwargs): + size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) + cpasync_bulk_g2s( + src[None, src_idx].iterator, + dst[None, dst_idx].iterator, + size=size, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(**new_kwargs): + size = const_expr(cute.size(src.shape) * src.element_type.width // 8) + cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py new file mode 100644 index 00000000000..6deeac30d34 --- /dev/null +++ b/flash_attn/cute/cute_dsl_utils.py @@ -0,0 +1,124 @@ +# Copyright (c) 2025, Tri Dao. + +import os +import pathlib +from typing import Tuple +from functools import partial, lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta + + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def load_cubin_module_data_patched(cubin_data, filepath): + pathlib.Path(filepath).write_bytes(cubin_data) + return load_cubin_module_data_og(cubin_data) + + +def cute_compile_patched(*args, **kwargs): + """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + output = cute_compile_og(*args, **kwargs) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output diff --git a/flash_attn/cute/fast_math.py b/flash_attn/cute/fast_math.py new file mode 100644 index 00000000000..c56ea89e798 --- /dev/null +++ b/flash_attn/cute/fast_math.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025, Tri Dao. + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 03d41b31e6b..ce0a1b6e5e9 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -11,12 +11,14 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warp -import cutlass.utils.ampere_helpers as sm80_utils_basic +from cutlass import Float32, Int32 +import cutlass.utils as utils_basic from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionBackwardSm80: @@ -31,6 +33,7 @@ def __init__( num_stages_Q: int = 2, num_stages_dO: int = 2, num_threads: int = 256, + pack_gqa: bool = False, is_causal: bool = False, SdP_swapAB: bool = False, dKV_swapAB: bool = False, @@ -69,6 +72,7 @@ def __init__( self.m_block_size = m_block_size self.n_block_size = n_block_size self.num_threads = num_threads + self.pack_gqa = pack_gqa self.is_causal = is_causal self.num_stages_Q = num_stages_Q self.num_stages_dO = num_stages_dO @@ -125,7 +129,7 @@ def can_implement( smem_usage_V = n_block_size * head_dim_v * 2 smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K - smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False return True @@ -141,6 +145,10 @@ def _check_type( mdQaccum_type: Type[cutlass.Numeric], mdK_type: Type[cutlass.Numeric], mdV_type: Type[cutlass.Numeric], + mCuSeqlensQ_type: Type[cutlass.Numeric] | None, + mCuSeqlensK_type: Type[cutlass.Numeric] | None, + mSeqUsedQ_type: Type[cutlass.Numeric] | None, + mSeqUsedK_type: Type[cutlass.Numeric] | None, ): if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): raise TypeError("All tensors must have the same data type") @@ -158,6 +166,14 @@ def _check_type( raise TypeError("dPsum tensor must be Float32") if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") + if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): + raise TypeError("cuSeqlensQ tensor must be Int32") + if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): + raise TypeError("cuSeqlensK tensor must be Int32") + if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): + raise TypeError("SeqUsedQ tensor must be Int32") + if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): + raise TypeError("SeqUsedK tensor must be Int32") assert mQ_type == self.dtype def _setup_attributes(self): @@ -245,11 +261,22 @@ def _setup_attributes(self): self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout) self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout) async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width - atom_async_copy_accum = cute.make_copy_atom( - cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), - cutlass.Float32, - num_bits_per_copy=universal_copy_bits, - ) + + # I think we wouldn't require this with smarter padding + if cutlass.const_expr(not self.varlen_q): + async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + atom_async_copy_accum = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + cutlass.Float32, + num_bits_per_copy=universal_copy_bits, + ) + else: + async_copy_elems_accum = 1 + atom_async_copy_accum = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=cutlass.Float32.width, + ) self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( atom_async_copy_accum, cute.make_layout(self.num_threads), @@ -262,25 +289,25 @@ def _setup_attributes(self): cute.make_layout(self.num_threads), cute.make_layout(1) ) - if self.qhead_per_kvhead > 1: + if cutlass.const_expr(self.qhead_per_kvhead > 1): self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum def _get_tiled_mma(self): num_mma_warps = self.num_threads // 32 - AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if not self.SdP_swapAB else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1) + AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1) tiled_mma_sdp = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutSdP, permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16), ) - AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if not self.dKV_swapAB else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1) + AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1) tiled_mma_dkv = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutdKV, permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16), ) - AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if not self.dQ_swapAB else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) tiled_mma_dq = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutdQ, @@ -293,7 +320,7 @@ def _get_shared_storage_cls(self): cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout) ] - cosize_sQV = utils.max_constexpr(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] sLSE_struct, sdPsum_struct = [ cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128] @@ -343,19 +370,55 @@ def __call__( mdV: cute.Tensor, softmax_scale: cutlass.Float32, stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + mdQ_semaphore: Optional[cute.Tensor] = None, ): + assert mdQ_semaphore is None, "semaphore not supported yet" # Get the data type and check if it is fp16 or bf16 self._check_type(*(t.element_type if t is not None else None - for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV))) + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)] + self.varlen_q = (mCuSeqlensQ is not None) self._setup_attributes() SharedStorage = self._get_shared_storage_cls() tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma() - # grid_dim: (n_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mK.shape[1], self.n_block_size), - cute.size(mQ.shape[2]), - cute.size(mQ.shape[0]), + + num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2] + + if cutlass.const_expr(mCuSeqlensK is not None): + TileScheduler = SingleTileVarlenScheduler + num_batch = mCuSeqlensK.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_batch = mK.shape[0] + + # Uses seqlen k, etc. since main bwd kernel's blocks are over n + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mK.shape[1], self.n_block_size), + num_head=num_head, + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=mK.shape[2], + headdim_v=mV.shape[2], + total_q=mK.shape[0], + tile_shape_mn=(self.n_block_size, self.m_block_size), + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, + mCuSeqlensQ=mCuSeqlensK, + mSeqUsedQ=mSeqUsedK, ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + softmax_scale_log2 = softmax_scale * math.log2(math.e) self.kernel( mQ, @@ -367,6 +430,10 @@ def __call__( mdQaccum, mdK, mdV, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, softmax_scale, softmax_scale_log2, self.sQ_layout, @@ -386,6 +453,8 @@ def __call__( tiled_mma_dkv, tiled_mma_dq, SharedStorage, + tile_sched_params, + TileScheduler, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -402,9 +471,13 @@ def kernel( mdO: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, - mdQaccu: cute.Tensor, + mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, sQ_layout: cute.ComposedLayout, @@ -424,301 +497,333 @@ def kernel( tiled_mma_dkv: cute.TiledMma, tiled_mma_dq: cute.TiledMma, SharedStorage: cutlass.Constexpr, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - n_block, head_idx, batch_idx = cute.arch.block_idx() - - m_block_max = cute.ceil_div(mQ.shape[1], self.m_block_size) - m_block_min = 0 - if self.is_causal: - m_block_min = max( - (n_block * self.n_block_size + mQ.shape[1] - mK.shape[1]) // self.m_block_size, - m_block_min, + + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + + if work_tile.is_valid_tile: + seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK) + + m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) + m_block_min = 0 + if cutlass.const_expr(self.is_causal): + m_block_min = max( + (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size, + m_block_min, + ) + # TODO: return early if m_block_max == 0 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + blkQ_shape = (self.m_block_size, self.head_dim_padded) + blkK_shape = (self.n_block_size, self.head_dim_padded) + blkV_shape = (self.n_block_size, self.head_dim_v_padded) + blkdO_shape = (self.m_block_size, self.head_dim_v_padded) + + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[batch_idx, None, head_idx, None] + mLSE_cur = mLSE[batch_idx, head_idx, None] + mdO_cur = mdO[batch_idx, None, head_idx, None] + mdPsum_cur = mdPsum[batch_idx, head_idx, None] + mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] + else: + padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None]) + mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None]) + mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) + mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]) + head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)] + else: + mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)] + + # (m_block_size, head_dim, m_block) + gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0)) + # (n_block_size, head_dim) + gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0)) + # (n_block_size, head_dim_v) + gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0)) + # (m_block_size, head_dim_v, m_block) + gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0)) + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,)) + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,)) + gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sK_layout) + if cutlass.const_expr(not self.share_QV_smem): + sV = storage.sV.get_tensor(sV_layout) + else: + sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) + sdO = storage.sdO.get_tensor(sdO_layout) + sP = storage.sP.get_tensor(sPdS_layout) + sdS = storage.sdS.get_tensor(sPdS_layout) + sLSE = storage.sLSE.get_tensor(sLSE_layout) + sdPsum = storage.sdPsum.get_tensor(sLSE_layout) + sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout) + sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout) + + # Transpose view of tensors for tiled mma + sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)] + + gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx) + gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx) + gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tQgQ = gmem_thr_copy_QK.partition_S(gQ) + tQsQ = gmem_thr_copy_QK.partition_D(sQ) + # (CPY_Atom, CPY_N, CPY_K) + tKgK = gmem_thr_copy_QK.partition_S(gK) + tKsK = gmem_thr_copy_QK.partition_D(sK) + # (CPY_Atom, CPY_N, CPY_K) + tVgV = gmem_thr_copy_VdO.partition_S(gV) + tVsV = gmem_thr_copy_VdO.partition_D(sV) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tdOgdO = gmem_thr_copy_VdO.partition_S(gdO) + tdOsdO = gmem_thr_copy_VdO.partition_D(sdO) + tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE) + tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE) + tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum) + tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum) + tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma_sdp = tiled_mma_sdp.get_slice(tidx) + thr_mma_dkv = tiled_mma_dkv.get_slice(tidx) + thr_mma_dq = tiled_mma_dq.get_slice(tidx) + acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded)) + acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded)) + acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32) + acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32) + acc_dK.fill(0.0) + acc_dV.fill(0.0) + + tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB) + tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB) + + LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None) + tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice] + tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice] + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, ) - # TODO: return early if m_block_max == 0 + smem_copy_atom_transposed = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + ) + smem_thr_copy_QdO = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + smem_thr_copy_KV = utils.make_tiled_copy_B( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + # TODO: should this be smem_copy_atom_transposed? + smem_thr_copy_PdSt = utils.make_tiled_copy_A( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_QdOt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_dS = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + smem_thr_copy_Kt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + # TODO: what's the number of bits? What if SdP_swapAB + r2s_thr_copy_PdS = cute.make_tiled_copy_C( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width + ), + tiled_mma_sdp, + ).get_slice(tidx) - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - blkQ_shape = (self.m_block_size, self.head_dim_padded) - blkK_shape = (self.n_block_size, self.head_dim_padded) - blkV_shape = (self.n_block_size, self.head_dim_v_padded) - blkdO_shape = (self.m_block_size, self.head_dim_v_padded) - # (m_block_size, head_dim, m_block) - gQ = cute.local_tile(mQ[batch_idx, None, head_idx, None], blkQ_shape, (None, 0)) - # (n_block_size, head_dim) - head_idx_kv = head_idx // self.qhead_per_kvhead - gK = cute.local_tile(mK[batch_idx, None, head_idx_kv, None], blkK_shape, (n_block, 0)) - # (n_block_size, head_dim_v) - gV = cute.local_tile(mV[batch_idx, None, head_idx_kv, None], blkV_shape, (n_block, 0)) - # (m_block_size, head_dim_v, m_block) - gdO = cute.local_tile(mdO[batch_idx, None, head_idx, None], blkdO_shape, (None, 0)) - gLSE = cute.local_tile(mLSE[batch_idx, head_idx, None], (self.m_block_size,), (None,)) - gdPsum = cute.local_tile(mdPsum[batch_idx, head_idx, None], (self.m_block_size,), (None,)) - gdQaccum = cute.local_tile(mdQaccu[batch_idx, head_idx, None], (self.m_block_size * self.head_dim_padded,), (None,)) + tSsQ = smem_thr_copy_QdO.partition_S(sQ) + tdPsdO = smem_thr_copy_QdO.partition_S(sdO) + tSsK = smem_thr_copy_KV.partition_S(sK) + tdPsV = smem_thr_copy_KV.partition_S(sV) + tdVsPt = smem_thr_copy_PdSt.partition_S(sPt) + tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt) + tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt) + tdKsQt = smem_thr_copy_QdOt.partition_S(sQt) + tdQsdS = smem_thr_copy_dS.partition_S(sdS) + tdQsKt = smem_thr_copy_Kt.partition_S(sKt) + tPsP = r2s_thr_copy_PdS.partition_D(sP) + tdSsdS = r2s_thr_copy_PdS.partition_D(sdS) - # /////////////////////////////////////////////////////////////////////////////// - # Get shared memory buffer - # /////////////////////////////////////////////////////////////////////////////// - smem = cutlass.utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - sQ = storage.sQ.get_tensor(sQ_layout) - sK = storage.sK.get_tensor(sK_layout) - if cutlass.const_expr(not self.share_QV_smem): - sV = storage.sV.get_tensor(sV_layout) - else: - sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) - sdO = storage.sdO.get_tensor(sdO_layout) - sP = storage.sP.get_tensor(sPdS_layout) - sdS = storage.sdS.get_tensor(sPdS_layout) - sLSE = storage.sLSE.get_tensor(sLSE_layout) - sdPsum = storage.sdPsum.get_tensor(sLSE_layout) - sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout) - sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout) - - # Transpose view of tensors for tiled mma - sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)] - - gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx) - gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx) - gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx) - gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) - # (CPY_Atom, CPY_M, CPY_K, m_block) - tQgQ = gmem_thr_copy_QK.partition_S(gQ) - tQsQ = gmem_thr_copy_QK.partition_D(sQ) - # (CPY_Atom, CPY_N, CPY_K) - tKgK = gmem_thr_copy_QK.partition_S(gK) - tKsK = gmem_thr_copy_QK.partition_D(sK) - # (CPY_Atom, CPY_N, CPY_K) - tVgV = gmem_thr_copy_VdO.partition_S(gV) - tVsV = gmem_thr_copy_VdO.partition_D(sV) - # (CPY_Atom, CPY_M, CPY_K, m_block) - tdOgdO = gmem_thr_copy_VdO.partition_S(gdO) - tdOsdO = gmem_thr_copy_VdO.partition_D(sdO) - tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE) - tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE) - tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum) - tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum) - tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQcQ = gmem_thr_copy_QK.partition_S(cQ) + t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ) + if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + tdOcdO = tQcQ + t0dOcdO = t0QcQ + else: + cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tdOcdO = gmem_thr_copy_VdO.partition_S(cdO) + t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO) + cLSE = cute.make_identity_tensor((self.m_block_size,)) + tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE) - # /////////////////////////////////////////////////////////////////////////////// - # Tile MMA compute thread partitions and allocate accumulators - # /////////////////////////////////////////////////////////////////////////////// - thr_mma_sdp = tiled_mma_sdp.get_slice(tidx) - thr_mma_dkv = tiled_mma_dkv.get_slice(tidx) - thr_mma_dq = tiled_mma_dq.get_slice(tidx) - acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded)) - acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded)) - acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32) - acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32) - acc_dK.fill(0.0) - acc_dV.fill(0.0) - - tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) - tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB) - tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) - tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB) - tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB) - tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) - tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB) - tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) - tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB) - tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB) - - LSEslice = (None, 0, None) if not self.SdP_swapAB else (0, None, None) - tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice] - tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice] + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and + # use "if" on the mn dimension. + # This is to reduce register pressure and gets 2-3% performance gain. - # /////////////////////////////////////////////////////////////////////////////// - # Smem copy atom tiling - # /////////////////////////////////////////////////////////////////////////////// - smem_copy_atom = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, - ) - smem_copy_atom_transposed = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, - ) - smem_thr_copy_QdO = utils.make_tiled_copy_A( - smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB - ).get_slice(tidx) - smem_thr_copy_KV = utils.make_tiled_copy_B( - smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB - ).get_slice(tidx) - # TODO: should this be smem_copy_atom_transposed? - smem_thr_copy_PdSt = utils.make_tiled_copy_A( - smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB - ).get_slice(tidx) - smem_thr_copy_QdOt = utils.make_tiled_copy_B( - smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB - ).get_slice(tidx) - smem_thr_copy_dS = utils.make_tiled_copy_A( - smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB - ).get_slice(tidx) - smem_thr_copy_Kt = utils.make_tiled_copy_B( - smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB - ).get_slice(tidx) - # TODO: what's the number of bits? What if SdP_swapAB - r2s_thr_copy_PdS = utils.make_tiled_copy_C( - cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width - ), - tiled_mma_sdp, - ).get_slice(tidx) - - tSsQ = smem_thr_copy_QdO.partition_S(sQ) - tdPsdO = smem_thr_copy_QdO.partition_S(sdO) - tSsK = smem_thr_copy_KV.partition_S(sK) - tdPsV = smem_thr_copy_KV.partition_S(sV) - tdVsPt = smem_thr_copy_PdSt.partition_S(sPt) - tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt) - tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt) - tdKsQt = smem_thr_copy_QdOt.partition_S(sQt) - tdQsdS = smem_thr_copy_dS.partition_S(sdS) - tdQsKt = smem_thr_copy_Kt.partition_S(sKt) - tPsP = r2s_thr_copy_PdS.partition_D(sP) - tdSsdS = r2s_thr_copy_PdS.partition_D(sdS) + d_head = mQ.shape[cute.rank(mQ) - 1] + d_head_v = mdO.shape[cute.rank(mdO) - 1] - # /////////////////////////////////////////////////////////////////////////////// - # Predicate: Mark indices that need to copy when problem_shape isn't a multiple - # of tile_shape - # /////////////////////////////////////////////////////////////////////////////// - # Construct identity layout for KV - cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) - tQcQ = gmem_thr_copy_QK.partition_S(cQ) - t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ) - if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): - tdOcdO = tQcQ - t0dOcdO = t0QcQ - else: - cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - tdOcdO = gmem_thr_copy_VdO.partition_S(cdO) - t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO) - cLSE = cute.make_identity_tensor((self.m_block_size,)) - tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE) - - # Allocate predicate tensors for m and n, here we only allocate the tile of k, and - # use "if" on the mn dimension. - # This is to reduce register pressure and gets 2-3% performance gain. - tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[3]) - if cutlass.const_expr(self.same_hdim_kv): - tdOpdO = tQpQ - else: - tdOpdO = utils.predicate_k(tdOcdO, limit=mdO.shape[3]) - - # group parameters for compute_one_m_block - mma_params = SimpleNamespace( - thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq, - tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV, - tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ, - tdQrdS=tdQrdS, tdQrK=tdQrK, - acc_dK=acc_dK, acc_dV=acc_dV, - ) - smem_copy_params = SimpleNamespace( - smem_thr_copy_QdO=smem_thr_copy_QdO, - smem_thr_copy_KV=smem_thr_copy_KV, - smem_thr_copy_PdSt=smem_thr_copy_PdSt, - smem_thr_copy_QdOt=smem_thr_copy_QdOt, - smem_thr_copy_dS=smem_thr_copy_dS, - smem_thr_copy_Kt=smem_thr_copy_Kt, - r2s_thr_copy_PdS=r2s_thr_copy_PdS, - tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV, - tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma, - tPsP=tPsP, tdSsdS=tdSsdS, - tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt, - tdQsdS=tdQsdS, tdQsKt=tdQsKt, - ) - gmem_copy_params = SimpleNamespace( - gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum - ) - seqlen = SeqlenInfo(batch_idx, mQ.shape[1], mK.shape[1]) - load_Q_LSE = partial( - self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, - tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, - tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q - ) - load_dO_dPsum = partial( - self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE, - tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO, - tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q - ) - compute_one_m_block = partial( - self.compute_one_m_block, mma_params=mma_params, - smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, - load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, - m_block_max=m_block_max, - softmax_scale_log2=softmax_scale_log2, - ) + tQpQ = utils.predicate_k(tQcQ, limit=d_head) + if cutlass.const_expr(self.same_hdim_kv): + tdOpdO = tQpQ + else: + tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v) - # /////////////////////////////////////////////////////////////////////////////// - # Prologue - # /////////////////////////////////////////////////////////////////////////////// - # Start async loads of the last mn-tile, where we take care of the mn residue - self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, - headdim=mV.shape[3]) - if cutlass.const_expr(self.V_in_regs): + # group parameters for compute_one_m_block + mma_params = SimpleNamespace( + thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq, + tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV, + tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ, + tdQrdS=tdQrdS, tdQrK=tdQrK, + acc_dK=acc_dK, acc_dV=acc_dV, + ) + smem_copy_params = SimpleNamespace( + smem_thr_copy_QdO=smem_thr_copy_QdO, + smem_thr_copy_KV=smem_thr_copy_KV, + smem_thr_copy_PdSt=smem_thr_copy_PdSt, + smem_thr_copy_QdOt=smem_thr_copy_QdOt, + smem_thr_copy_dS=smem_thr_copy_dS, + smem_thr_copy_Kt=smem_thr_copy_Kt, + r2s_thr_copy_PdS=r2s_thr_copy_PdS, + tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV, + tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma, + tPsP=tPsP, tdSsdS=tdSsdS, + tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt, + tdQsdS=tdQsdS, tdQsKt=tdQsKt, + ) + gmem_copy_params = SimpleNamespace( + gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum + ) + load_Q_LSE = partial( + self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, + tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, + tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + load_dO_dPsum = partial( + self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE, + tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO, + tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + compute_one_m_block = partial( + self.compute_one_m_block, mma_params=mma_params, + smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, + load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, + m_block_max=m_block_max, + softmax_scale_log2=softmax_scale_log2, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, + headdim=d_head_v) + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_commit_group() + self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, + headdim=d_head) cute.arch.cp_async_commit_group() - self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, - headdim=mK.shape[3]) - cute.arch.cp_async_commit_group() - if cutlass.const_expr(self.V_in_regs): - cute.arch.cp_async_wait_group(1) - cute.arch.barrier() - tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV) - cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view) - # Sync to avoid loading Q to smem_q, which overlaps with smem_v - cute.arch.barrier() + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_wait_group(1) + cute.arch.barrier() + tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV) + cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view) + # Sync to avoid loading Q to smem_q, which overlaps with smem_v + cute.arch.barrier() - m_block = m_block_min - assert self.num_stages_Q >= self.num_stages_dO - for stage in range(self.num_stages_Q): - if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1): - if stage == 0 or m_block + stage < m_block_max: - load_Q_LSE(m_block + stage, smem_pipe_write_q=stage) - cute.arch.cp_async_commit_group() - if cutlass.const_expr(stage < self.num_stages_dO): - if stage == 0 or m_block + stage < m_block_max: - load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage) - cute.arch.cp_async_commit_group() + m_block = m_block_min + assert self.num_stages_Q >= self.num_stages_dO + for stage in cutlass.range_constexpr(self.num_stages_Q): + if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1): + if stage == 0 or m_block + stage < m_block_max: + load_Q_LSE(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() + if cutlass.const_expr(stage < self.num_stages_dO): + if stage == 0 or m_block + stage < m_block_max: + load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() - # /////////////////////////////////////////////////////////////////////////////// - # Mainloop - # /////////////////////////////////////////////////////////////////////////////// - # Start processing of the first n-block. - mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) - mask_fn = partial( - mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, - mask_seqlen=True, mask_causal=self.is_causal - ) - smem_pipe_read_q = cutlass.Int32(0) - smem_pipe_read_do = cutlass.Int32(0) - smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1) - smem_pipe_write_do = cutlass.Int32(0) - for m_tile in cutlass.range_dynamic(m_block_min, m_block_max, unroll=1): - compute_one_m_block( - m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do, - mask_fn=mask_fn, + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + # Start processing of the first n-block. + mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, + mask_seqlen=True, mask_causal=self.is_causal ) - smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q) - smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO) - smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q) - smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO) + smem_pipe_read_q = cutlass.Int32(0) + smem_pipe_read_do = cutlass.Int32(0) + smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1) + smem_pipe_write_do = cutlass.Int32(0) + for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1): + compute_one_m_block( + m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do, + mask_fn=mask_fn, + ) + smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q) + smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO) + smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q) + smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO) - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue - # /////////////////////////////////////////////////////////////////////////////// - # If GQA, we scale dK in the postprocessing kernel instead - if cutlass.const_expr(self.qhead_per_kvhead == 1): - acc_dK.store(acc_dK.load() * softmax_scale) - # reuse sK and sV data iterator - sdK = cute.make_tensor(sK.iterator, sK_layout) - sdV = cute.make_tensor(sV.iterator, sV_layout) - self.epilogue( - acc_dK, acc_dV, mdK, mdV, sdK, sdV, - gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv, - tidx, n_block, head_idx, batch_idx - ) + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # If GQA, we scale dK in the postprocessing kernel instead + if cutlass.const_expr(self.qhead_per_kvhead == 1): + acc_dK.store(acc_dK.load() * softmax_scale) + # reuse sK and sV data iterator + sdK = cute.make_tensor(sK.iterator, sK_layout) + sdV = cute.make_tensor(sV.iterator, sV_layout) + self.epilogue( + acc_dK, acc_dV, mdK, mdV, sdK, sdV, + gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv, + tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v + ) @cute.jit def compute_one_m_block( @@ -738,7 +843,7 @@ def compute_one_m_block( mask_fn: Optional[Callable] = None, ): def load_Q_next(): - m_block_next = m_block + (self.num_stages_Q - 1 if self.num_stages_Q > 1 else 1) + m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1) if m_block_next < m_block_max: load_Q_LSE(m_block_next, smem_pipe_write_q) cute.arch.cp_async_commit_group() @@ -750,22 +855,22 @@ def load_dO_next(): # MMA S acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C( - (self.m_block_size, self.n_block_size) if not self.SdP_swapAB else (self.n_block_size, self.m_block_size) + (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size) ) acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32) acc_S.fill(0.0) - cute.arch.cp_async_wait_group(1 if self.num_stages_Q > 1 else 0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0) cute.arch.barrier() sm80_utils.gemm( mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK, - smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if self.num_stages_Q > 1 else 0], + smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], smem_copy_params.tSsK, smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, swap_AB=self.SdP_swapAB, ) tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0]) cute.autovec_copy( - smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if self.num_stages_Q > 1 else 0], tLSErLSE + smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE ) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) @@ -774,31 +879,31 @@ def load_dO_next(): # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE) - for r in range(cute.size(acc_S_mn, mode=[0])): + for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True): acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r])) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # MMA dP acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32) acc_dP.fill(0.0) - cute.arch.cp_async_wait_group(1 if self.num_stages_dO > 1 else 0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0) cute.arch.barrier() sm80_utils.gemm( mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV, - smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if self.num_stages_dO > 1 else 0], + smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], smem_copy_params.tdPsV, smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, - hook_fn=load_Q_next if self.num_stages_Q > 1 else None, + hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None, swap_AB=self.SdP_swapAB, ) tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0]) cute.autovec_copy( - smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if self.num_stages_dO > 1 else 0], tLSErdPsum + smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum ) acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) - for r in range(cute.size(acc_dP_mn, mode=[0])): + for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True): acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) rP = cute.make_fragment_like(acc_S, self.dtype) @@ -823,7 +928,7 @@ def load_dO_next(): sm80_utils.gemm( mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO, smem_copy_params.tdVsPt, - smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if self.num_stages_dO > 1 else 0], + smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, A_in_regs=self.Mma_dKV_is_RS, swap_AB=self.dKV_swapAB, @@ -834,7 +939,7 @@ def load_dO_next(): # MMA dQ def dQ_mma(hook_fn): acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C( - (self.m_block_size, self.head_dim_padded) if not self.dQ_swapAB else (self.head_dim_padded, self.m_block_size) + (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, self.m_block_size) ) acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32) acc_dQ.fill(0.0) @@ -849,8 +954,7 @@ def dQ_mma(hook_fn): acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ) tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block] assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic) - # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(acc_dQ) - for i in range(cute.size(acc_dQ_atomic)): + for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True): utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i)) # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1]) # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ) @@ -867,7 +971,7 @@ def dQ_mma(hook_fn): sm80_utils.gemm( mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ, smem_copy_params.tdKsdSt, - smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if self.num_stages_Q > 1 else 0], + smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, A_in_regs=self.Mma_dKV_is_RS, swap_AB=self.dKV_swapAB, @@ -894,6 +998,9 @@ def epilogue( n_block: cutlass.Int32, num_head: cutlass.Int32, batch_size: cutlass.Int32, + seqlen: SeqlenInfoQK, + d_head: cutlass.Int32, + d_head_v: cutlass.Int32 ): rdV = cute.make_fragment_like(acc_dV, self.dtype) rdV.store(acc_dV.load().to(self.dtype)) @@ -902,6 +1009,9 @@ def epilogue( gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) + batch_idx = batch_size + head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head + if cutlass.const_expr(self.qhead_per_kvhead == 1): # Make sure all threads have finished reading K and V, otherwise we get racy dQ # because smem_q could be changed. @@ -910,7 +1020,7 @@ def epilogue( smem_copy_atom_dKV = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width ) - smem_thr_copy_dKV = utils.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx) + smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx) taccdVrdV = smem_thr_copy_dKV.retile(rdV) taccdKrdK = smem_thr_copy_dKV.retile(rdK) taccdVsdV = smem_thr_copy_dKV.partition_D(sdV) @@ -919,10 +1029,16 @@ def epilogue( cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)] + else: + mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)] + blkdK_shape = (self.n_block_size, self.head_dim_padded) blkdV_shape = (self.n_block_size, self.head_dim_v_padded) - gdK = cute.local_tile(mdK[batch_size, None, num_head, None], blkdK_shape, (n_block, 0)) - gdV = cute.local_tile(mdV[batch_size, None, num_head, None], blkdV_shape, (n_block, 0)) + gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0)) + gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0)) tdKsdK = gmem_thr_copy_dK.partition_S(sdK) tdKgdK = gmem_thr_copy_dK.partition_D(gdK) tdVsdV = gmem_thr_copy_dV.partition_S(sdV) @@ -947,44 +1063,52 @@ def epilogue( cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) tdVcdV = gmem_thr_copy_dV.partition_S(cdV) t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) - tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[3]) + tdKpdK = utils.predicate_k(tdKcdK, limit=d_head) if cutlass.const_expr(self.same_hdim_kv): tdVpdV = tdKpdK else: - tdVpdV = utils.predicate_k(tdVcdV, limit=mdV.shape[3]) + tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v) # copy acc dK and acc_dV from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])): - if t0dKcdK[0, rest_m, 0][0] < mdK.shape[1] - n_block * self.n_block_size - tdKcdK[0][0]: + if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]: cute.copy( gmem_tiled_copy_dK, tdKrdK[None, rest_m, None], tdKgdK[None, rest_m, None], - pred=tdKpdK[None, rest_m, None] if self.check_hdim_oob else None, + pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None, ) for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): - if t0dVcdV[0, rest_m, 0][0] < mdV.shape[1] - n_block * self.n_block_size - tdVcdV[0][0]: + if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]: cute.copy( gmem_tiled_copy_dV, tdVrdV[None, rest_m, None], tdVgdV[None, rest_m, None], - pred=tdVpdV[None, rest_m, None] if self.check_hdim_v_oob else None, + pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None, ) else: # qhead_per_kvhead > 1, do atomic add # For Sm90, we need to sync to avoid racy writes to smem_q # For Sm80, we don't need to sync since we're not touching smem - num_head_kv = num_head // self.qhead_per_kvhead - gdV = cute.local_tile(mdV[batch_size, num_head_kv, None], (self.n_block_size * self.head_dim_v_padded,), (n_block,)) - gdK = cute.local_tile(mdK[batch_size, num_head_kv, None], (self.n_block_size * self.head_dim_padded,), (n_block,)) + head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)] + else: + padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size + mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None]) + mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None]) + + gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,)) + gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,)) tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV) tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK) acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV) acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK) assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum) assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum) - for i in range(cute.size(acc_dV_atomic)): + for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True): utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i)) - for i in range(cute.size(acc_dK_atomic)): + for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True): utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i)) @cute.jit @@ -1005,16 +1129,16 @@ def load_K( tKcK = gmem_thr_copy.partition_S(cK) t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK) tKpK = utils.predicate_k(tKcK, limit=headdim) - for n in range(cute.size(tKsK.shape[1])): + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size: # Instead of using tKcK, we using t0KcK and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0] predicate = cute.make_fragment_like(tKpK[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tKpK[i, n, k] if self.check_hdim_oob else True) and predicate_n + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n cute.copy( gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate, ) @@ -1034,16 +1158,16 @@ def load_V( tVcV = gmem_thr_copy.partition_S(cV) t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV) tVpV = utils.predicate_k(tVcV, limit=headdim) - for n in range(cute.size(tVsV.shape[1])): + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: # Instead of using tVcV, we using t0VcV and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time. predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0] predicate = cute.make_fragment_like(tVpV[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tVpV[i, n, k] if self.check_hdim_oob else True) and predicate_n + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n cute.copy( gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate, ) @@ -1065,31 +1189,31 @@ def load_Q_LSE( smem_pipe_write_q: cutlass.Int32, seqlen: cutlass.Int32, ): - for m in range(cute.size(tQsQ.shape[1])): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size: # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0] predicate = cute.make_fragment_like(tQpQ[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tQpQ[i, m, k] if self.check_hdim_oob else True) and predicate_m + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m cute.copy( gmem_tiled_copy_Q, tQgQ[None, m, None, block], - tQsQ[None, m, None, smem_pipe_write_q if self.num_stages_Q > 1 else 0], + tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q) > 1 else 0], pred=predicate, ) # We need to clear the sQ smem tiles since we'll use sQt for mma_dK # We made sure LSE length is padded so we read `kBlockM` elements so that all # elements in sLSE are filled. Without this we might have uninitialized sLSE values. - for m in range(cute.size(tLSEsLSE.shape[1])): + for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])): if tLSEcLSE[0, m][0] < self.m_block_size: cute.copy( gmem_tiled_copy_LSE, tLSEgLSE[None, m, block], - tLSEsLSE[None, m, smem_pipe_write_q if self.num_stages_Q > 1 else 0], + tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], ) @cute.jit @@ -1109,29 +1233,29 @@ def load_dO_dPsum( smem_pipe_write_q: cutlass.Int32, seqlen: cutlass.Int32, ): - for m in range(cute.size(tdOsdO.shape[1])): + for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])): # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size: # Instead of using tdOcdO, we using t0dOcdO and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time. predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0] predicate = cute.make_fragment_like(tdOpdO[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tdOpdO[i, m, k] if self.check_hdim_oob else True) and predicate_m + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m cute.copy( gmem_tiled_copy_dO, tdOgdO[None, m, None, block], - tdOsdO[None, m, None, smem_pipe_write_q if self.num_stages_dO > 1 else 0], + tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], pred=predicate, ) # We need to clear the sQ smem tiles since we'll use sQt for mma_dK # We made sure LSE length is padded so we read `kBlockM` elements so that all # elements in sLSE are filled. Without this we might have uninitialized sLSE values. - for m in range(cute.size(tdPsumgdPsum.shape[1])): + for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])): if tdPsumcdPsum[0, m][0] < self.m_block_size: cute.copy( gmem_tiled_copy_dPsum, tdPsumgdPsum[None, m, block], - tdPsumsdPsum[None, m, smem_pipe_write_q if self.num_stages_dO > 1 else 0], + tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], ) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index ccb33d2c026..14d746ba346 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -2,60 +2,73 @@ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h # from Cutlass C++ to Cute-DSL. import math -from typing import Type +from typing import Callable, Optional, Type, Literal import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass.cute.nvgpu import cpasync, warp +import cutlass.utils.hopper_helpers as sm90_utils_basic +import cutlass.utils.blackwell_helpers as sm100_utils_basic +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass import Float32, const_expr +from cutlass.utils import LayoutEnum -from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import hopper_helpers as sm90_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from flash_attn.cute.tile_scheduler import ( + ParamsBase, + SingleTileScheduler, + SingleTileVarlenScheduler, + TileSchedulerArguments, +) class FlashAttentionBackwardPostprocess: def __init__( self, dtype: Type[cutlass.Numeric], - # tiled_mma: cute.TiledMma, head_dim: int, - m_block_size: int = 128, + arch: Literal[80, 90, 100], + tile_m: int = 128, num_threads: int = 256, AtomLayoutMdQ: int = 1, dQ_swapAB: bool = False, ): - """Initializes the configuration for a flash attention v2 kernel. - - All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension - should be a multiple of 8. - + """ :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int """ self.dtype = dtype - self.m_block_size = m_block_size + self.tile_m = tile_m + assert arch in [80, 90, 100], ( + "Only Ampere (80), Hopper (90), and Blackwell (100) are supported" + ) + self.arch = arch # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 - self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) - self.check_hdim_oob = head_dim != self.head_dim_padded - # self.tiled_mma = tiled_mma + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_oob = head_dim != self.tile_hdim self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB @staticmethod - def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: + def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: """Check if the kernel can be implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int :return: True if the kernel can be implemented, False otherwise :rtype: bool @@ -68,114 +81,199 @@ def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: return False return True + def _get_tiled_mma(self): + if const_expr(self.arch == 80): + num_mma_warps = self.num_threads // 32 + atom_layout_dQ = ( + (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) + if const_expr(not self.dQ_swapAB) + else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + ) + tiled_mma = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + atom_layout_dQ, + permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16), + ) + elif const_expr(self.arch == 90): + num_mma_warp_groups = self.num_threads // 128 + atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ) + tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) + tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, # These don't matter, we only care about the accum + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + + (1,), + tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], + ) + else: + cta_group = tcgen05.CtaGroup.ONE + tiled_mma = sm100_utils_basic.make_trivial_tiled_mma( + self.dtype, + tcgen05.OperandMajorMode.MN, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Kt_major_mode + Float32, + cta_group, + (self.tile_m, self.tile_hdim), + ) + if const_expr(self.arch in [80, 90]): + assert self.num_threads == tiled_mma.size + return tiled_mma + def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # GMEM Tiled copy: # /////////////////////////////////////////////////////////////////////////////// # Thread layouts for copies universal_copy_bits = 128 - async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + async_copy_elems_accum = universal_copy_bits // Float32.width atom_async_copy_accum = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), - cutlass.Float32, + Float32, num_bits_per_copy=universal_copy_bits, ) # We don't do bound checking for the gmem -> smem load so we just assert here. - assert (self.m_block_size * self.head_dim_padded // async_copy_elems_accum) % self.tiled_mma.size == 0 + assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0 self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( atom_async_copy_accum, - cute.make_layout(self.tiled_mma.size), - cute.make_layout(async_copy_elems_accum) - ) - atom_universal_copy_accum = cute.make_copy_atom( - # multiply by 4 for Sm90 - cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width, - ) - self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - atom_universal_copy_accum, - cute.make_layout(self.tiled_mma.size), - cute.make_layout(1) # 4 for Sm90 + cute.make_layout(self.num_threads), + cute.make_layout(async_copy_elems_accum), ) + num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4 + if const_expr(self.arch == 80): + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_s2r_copy_elems + ) + self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) + elif const_expr(self.arch == 90): + num_threads_per_warp_group = 128 + num_mma_warp_groups = self.num_threads // 128 + self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), + cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout + cute.make_layout(128 // Float32.width), # val_layout + ) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups) + ) + else: + self.dQ_reduce_ncol = 32 + dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol + assert self.num_threads == 128 # TODO: currently hard-coded + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_s2r_copy_elems + ) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage) + ) - async_copy_elems = universal_copy_bits // self.dtype.width - # atom_universal_copy: universal copy atom for dQ store - atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, - ) - # tdQ_layout: thread layout for dQ store - assert self.head_dim_padded % async_copy_elems == 0 - gmem_threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, - self.tiled_mma.size) - assert self.tiled_mma.size % gmem_threads_per_row == 0 - tdQ_layout = cute.make_ordered_layout( - (self.tiled_mma.size // gmem_threads_per_row, gmem_threads_per_row), order=(1, 0), + self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( + self.dtype, self.tile_hdim, self.num_threads ) - # Value layouts for copies - vdQ_layout = cute.make_layout((1, async_copy_elems)) - self.gmem_tiled_copy_dQ = cute.make_tiled_copy_tv(atom_universal_copy, tdQ_layout, vdQ_layout) # /////////////////////////////////////////////////////////////////////////////// - # Shared memory layout: dQaccum / dQ + # Shared memory layout: dQ # /////////////////////////////////////////////////////////////////////////////// - self.sdQaccum_layout = cute.make_layout(self.m_block_size * self.head_dim_padded) # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, # then setting kBlockKSmem to 32 will cause "Static shape_div failure". # We want to treat it as 64 x 48, so kBlockKSmem should be 16. mma_shape_n = self.tiled_mma.get_tile_size(1) - sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) - self.sdQ_layout = cute.tile_to_shape( - sdQ_layout_atom, (self.m_block_size, self.head_dim_padded), (0, 1) - ) - + if const_expr(self.arch == 80): + sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) + self.sdQ_layout = cute.tile_to_shape( + sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1) + ) + elif const_expr(self.arch == 90): + self.sdQ_layout = sm90_utils.make_smem_layout( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim) + ) + else: + # TODO: this is hard-coded for hdim 128 + self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1 + ) @cute.jit def __call__( self, mdQaccum: cute.Tensor, mdQ: cute.Tensor, - scale: cute.Float32, + scale: cutlass.Float32, + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], stream: cuda.CUstream, ): # Get the data type and check if it is fp16 or bf16 - if cutlass.const_expr(not mdQ.element_type in [cutlass.Float16, cutlass.BFloat16]): + if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]): + if const_expr(mdQaccum is not None): + if const_expr(mdQaccum.element_type not in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") - num_mma_warps = self.num_threads // 32 - AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if not self.dQ_swapAB else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) - tiled_mma = cute.make_tiled_mma( - warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), - AtomLayoutdQ, - permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], ) - self.tiled_mma = tiled_mma + mdQaccum, mdQ = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mdQaccum, mdQ) + ] + self.tiled_mma = self._get_tiled_mma() self._setup_attributes() - smem_size = max(cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), - cute.size_in_bytes(self.dtype, self.sdQ_layout)) + smem_size = max( + cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), + cute.size_in_bytes(self.dtype, self.sdQ_layout), + ) - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mdQ.shape[1], self.m_block_size), - cute.size(mdQ.shape[2]), - cute.size(mdQ.shape[0]), + if const_expr(mCuSeqlensQ is not None): + TileScheduler = SingleTileVarlenScheduler + num_head = mdQ.shape[1] + num_batch = mCuSeqlensQ.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_head = mdQ.shape[2] + num_batch = mdQ.shape[0] + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mdQ.shape[1], self.tile_m), + num_head=num_head, + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=mdQ.shape[2], + headdim_v=0, + total_q=mdQ.shape[0], + tile_shape_mn=(self.tile_m, 1), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + # grid_dim: (m_block, num_head, batch_size) self.kernel( mdQaccum, mdQ, + mCuSeqlensQ, + mSeqUsedQ, scale, - tiled_mma, + self.tiled_mma, self.dQ_swapAB, self.sdQaccum_layout, self.sdQ_layout, self.g2s_tiled_copy_dQaccum, self.s2r_tiled_copy_dQaccum, self.gmem_tiled_copy_dQ, + tile_sched_params, + TileScheduler, ).launch( grid=grid_dim, - block=[tiled_mma.size, 1, 1], + block=[self.num_threads, 1, 1], smem=smem_size, stream=stream, ) @@ -185,7 +283,9 @@ def kernel( self, mdQaccum: cute.Tensor, mdQ: cute.Tensor, - scale: cute.Float32, + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + scale: cutlass.Float32, tiled_mma: cute.TiledMma, dQ_swapAB: cutlass.Constexpr, sdQaccum_layout: cute.Layout, @@ -193,96 +293,437 @@ def kernel( g2s_tiled_copy_dQaccum: cute.TiledCopy, s2r_tiled_copy_dQaccum: cute.TiledCopy, gmem_tiled_copy_dQ: cute.TiledCopy, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], ): - # Thread index, block index - tidx, _, _ = cute.arch.thread_idx() - m_block, num_head, batch_size = cute.arch.block_idx() - - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile(mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,)) - blkdQ_shape = (self.m_block_size, self.head_dim_padded) - gdQ = cute.local_tile(mdQ[batch_size, None, num_head, None], blkdQ_shape, (m_block, 0)) - # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer # /////////////////////////////////////////////////////////////////////////////// smem = cutlass.utils.SmemAllocator() sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) - sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) - - seqlen_q = mdQ.shape[1] - seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) - - # Step 1: load dQaccum from gmem to smem - g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) - tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) - tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum) - # print(tdQgdQaccum) - # print(tdQsdQaccum) - cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) - cute.arch.cp_async_commit_group() - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() - - # Step 2: load dQ from smem to rmem - s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) - tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) - # print(s2r_tiled_copy_dQaccum) - # print(sdQaccum) - # thr_mma = tiled_mma.get_slice(tidx) - # print(tiled_mma) - acc_shape = tiled_mma.partition_shape_C( - (self.m_block_size, self.head_dim_padded) if not dQ_swapAB - else (self.head_dim_padded, self.m_block_size) + sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum))) + if const_expr(self.arch in [80, 90]): + sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + else: + # extra stage dimension + sdQ = cute.make_tensor( + cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype), + sdQ_layout.outer, + )[None, None, 0] + sdQt = utils.transpose_view(sdQ) + + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + + m_block, num_head, batch_size, _ = work_tile.tile_idx + + if work_tile.is_valid_tile: + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + + seqlen = SeqlenInfoQK.create( + batch_size, + mdQ.shape[1], + 0, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=None, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=None, + ) + if const_expr(not seqlen.has_cu_seqlens_q): + mdQ_cur = mdQ[batch_size, None, num_head, None] + mdQaccum_cur = mdQaccum[batch_size, num_head, None] + head_dim = mdQ.shape[3] + else: + padded_offset_q = seqlen.offset_q + batch_size * self.tile_m + mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None]) + mdQaccum_cur = cute.domain_offset( + (padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None] + ) + head_dim = mdQ.shape[2] + + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.tile_hdim keeps alignment + # since statically divisible by 4 + + mdQaccum_cur_ptr = cute.make_ptr( + dtype=mdQaccum_cur.element_type, + value=mdQaccum_cur.iterator.toint(), + mem_space=mdQaccum_cur.iterator.memspace, + assumed_align=mdQaccum.iterator.alignment, + ) + mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) + + gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,)) + gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + + seqlen_q = seqlen.seqlen_q + seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) + + # Step 1: load dQaccum from gmem to smem + g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) + tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) + tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat) + cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # Step 2: load dQ from smem to rmem + s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) + tile_shape = (self.tile_m, self.tile_hdim) + acc = None + tiled_copy_t2r = None + if const_expr(self.arch in [80, 90]): + acc_shape = tiled_mma.partition_shape_C( + tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] + ) + acc = cute.make_fragment(acc_shape, cutlass.Float32) + assert cute.size(acc) == cute.size(tdQsdQaccum) + else: + thr_mma = tiled_mma.get_slice(0) # 1-CTA + dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim)) + tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape) + tdQcdQ = thr_mma.partition_C( + cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + ) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape + acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32) + tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.autovec_copy(tdQsdQaccum, tdQrdQaccum) + # Convert tdQrdQaccum from fp32 to fp16/bf16 + rdQ = cute.make_fragment_like(acc, self.dtype) + rdQ.store((acc.load() * scale).to(self.dtype)) + + # Step 3: Copy dQ from register to smem + cute.arch.barrier() # make sure all threads have finished loading dQaccum + if const_expr(self.arch in [80, 90]): + copy_atom_r2s_dQ = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.dQ_swapAB + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma) + else: + # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op( + # LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r, + # ) + # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r) + thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1)) # 128 threads + val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width)) + copy_atom_r2s_dQ = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_tv( + copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ + ) + thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx) + cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + if const_expr(self.arch in [80, 90]): + taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ) + else: + taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape + taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape) + taccdQsdQ = thr_copy_r2s_dQ.partition_D(sdQ if const_expr(not self.dQ_swapAB) else sdQt) + cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ) + + # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem + cute.arch.barrier() # make sure all smem stores are done + gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) + tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) + tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) + tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) + # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled + cute.autovec_copy(tdQsdQ, tdQrdQ) + + # Step 5: Copy dQ from register to gmem + tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) + tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim) + for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): + if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m: + cute.copy( + gmem_tiled_copy_dQ, + tdQrdQ[None, rest_m, None], + tdQgdQ[None, rest_m, None], + pred=tdQpdQ[None, rest_m, None], + ) + + +class FlashAttentionBackwardPostprocess_sm100(FlashAttentionBackwardPostprocess): + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + tile_m: int = 128, + num_threads: int = 256, + AtomLayoutMdQ: int = 1, + dQ_swapAB: bool = False, + ): + super().__init__( + dtype=dtype, + head_dim=head_dim, + arch=90, # tmp dummy placement for now + tile_m=tile_m, + num_threads=num_threads, + AtomLayoutMdQ=AtomLayoutMdQ, + dQ_swapAB=dQ_swapAB, ) - acc = cute.make_fragment(acc_shape, cutlass.Float32) - assert cute.size(acc) == cute.size(tdQsdQaccum) - tdQrdQaccum = s2r_thr_copy_dQaccum.retile(acc) - # Somehow even after retiling the layouts of tdQsdQaccum and tdQrdQaccum are different. - # So we have to do a for loop to copy - # cute.copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum) - # print(acc) - # print(tdQsdQaccum) # ((1, 1), 64) - # print(tdQrdQaccum) # ((1, 4), 4, 4) - for i in range(cute.size(tdQsdQaccum)): - tdQrdQaccum[i] = tdQsdQaccum[i] - # Convert tdQrdQaccum from fp32 to fp16/bf16 - rdQ = cute.make_fragment_like(acc, self.dtype) - rdQ.store((acc.load() * scale).to(self.dtype)) - - # Step 3: Copy dQ from register to smem - cute.arch.barrier() # make sure all threads have finished loading dQaccum - smem_copy_atom_dQ = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=cutlass.Float32.width + + def _setup_attributes(self): + self.num_stages = self.tile_hdim // 32 # 2 for D=64, 4 for D=128 + + self.sdQaccum_layout = cute.make_layout( + shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32) ) - smem_thr_copy_dQ = utils.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) - taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) - taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ) - cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) - # print(taccdQrdQ) - # print(taccdQsdQ) - - # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem - gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) - tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) - tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) - tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) - cute.arch.barrier() # make sure all smem stores are done - # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled - cute.autovec_copy(tdQsdQ, tdQrdQ) - - # Step 5: Copy dQ from register to gmem - cdQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) - tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) - tdQpdQ = utils.predicate_k(tdQcdQ, limit=mdQ.shape[3]) - for rest_m in cutlass.range_constexpr(cute.size(tdQrdQ.shape[1])): - if tdQcdQ[0, rest_m, 0][0] < mdQ.shape[1] - m_block * self.m_block_size: - cute.copy( - gmem_tiled_copy_dQ, - tdQrdQ[None, rest_m, None], - tdQgdQ[None, rest_m, None], - pred=tdQpdQ[None, rest_m, None], - ) + self.epi_tile_q = (self.tile_m, self.tile_hdim) + self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( + self.dtype, + LayoutEnum.ROW_MAJOR, + self.epi_tile_q, + 1, + ) + + @cute.jit + def __call__( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + scale: cutlass.Float32, + stream: cuda.CUstream, + ): + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mdQaccum, mdQ = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mdQaccum, mdQ) + ] + # (b, h, s*d) -> (s*d, h, b) + mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0])) + # (b, s, h, d) -> (s, d, h, b) + mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1, 3, 2, 0])) + + self._setup_attributes() + + grid_dim = [ + cute.ceil_div(mdQ.shape[0], self.tile_m), + cute.size(mdQ.shape[2]), + cute.size(mdQ.shape[3]), + ] + + cta_group = tcgen05.CtaGroup.ONE + self.mma_tiler_dsk = (self.tile_m, self.tile_hdim) + + dS_major_mode = tcgen05.OperandMajorMode.MN + kt_major_mode_dsq = tcgen05.OperandMajorMode.MN + + tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma( + cutlass.BFloat16, + dS_major_mode, + kt_major_mode_dsq, + cutlass.Float32, + cta_group, + self.mma_tiler_dsk, + ) + + dQ_cta_v_layout = cute.composition(cute.make_identity_layout(mdQ.shape), self.mma_tiler_dsk) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + tma_atom_dQ, tma_tensor_dQ = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_store_op, + mdQ, + cute.select(self.sdQ_layout, mode=[0, 1]), + dQ_cta_v_layout, + ) + + buffer_align_bytes = 1024 + + @cute.struct + class SharedStorage: + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)], + 128, + ] + + sdQ: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)], + buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.kernel( + mdQaccum, + tma_tensor_dQ, + tma_atom_dQ, + self.sdQaccum_layout, + self.sdQ_layout, + tiled_mma_dsk, + scale, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + tma_atom_dQ: cute.CopyAtom, + sdQaccum_layout: cute.Layout, + sdQ_layout: cute.ComposedLayout, + tiled_mma_dsk: cute.TiledMma, + scale: cutlass.Float32, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + m_block, head_idx, batch_idx = cute.arch.block_idx() + + # SMEM + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + swz128 = cute.make_swizzle(3, 4, 3) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128) + + sdQ = storage.sdQ.get_tensor(sdQ_layout.outer, swizzle=sdQ_layout.inner) + + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + mdQ_cur = mdQ[None, None, head_idx, batch_idx] + + thr_mma_dsk = tiled_mma_dsk.get_slice(tidx) + dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) + tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout) + + tmem_ld_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32 + ) + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + + cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) + tdQcdQ = thr_mma_dsk.partition_C(cdQ) + tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) + tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) + + gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,)) + + num_reduce_warps = 4 + num_reduce_threads = cute.arch.WARP_SIZE * num_reduce_warps + + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128 + ) + tiler_mn, layout_tv = cute.make_layout_tv( + thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1), + val_layout=cute.make_layout(shape=4, stride=1), + ) + G2S_tiled_copy_dQaccum = cute.make_tiled_copy( + atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn + ) + + smem_thr_copy_g2s = G2S_tiled_copy_dQaccum.get_slice(tidx) + + # S->R + tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, cutlass.Float32) + tiled_smem_store_s2r = cute.make_tiled_copy( + atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn + ) + + s2r_thr_copy_dQaccum = tiled_smem_store_s2r.get_slice(tidx) + tdQsdQ_s2r = s2r_thr_copy_dQaccum.partition_S(sdQaccum) + tdQrdQ_s2r = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_t2r.shape) + + # R->S + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld + ) + tiled_smem_store_r2s = cute.make_tiled_copy( + smem_copy_atom, + layout_tv=tiled_tmem_ld.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld.tiler_mn, + ) + tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ)) + tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype) + + num_stages = cute.size(tdQrdQ_t2r, mode=[1]) + for stage in cutlass.range_constexpr(num_stages): + # G->S + gdQaccum_stage = cute.local_tile( + gdQaccum, + (self.tile_m * 32,), + (stage,), + ) + + gdQaccum_layout_g2s = cute.make_layout(shape=(self.tile_m * 32, 1), stride=(1, 0)) + gdQaccum_stage_g2s = cute.make_tensor( + cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s + ) + + tdQgdQ = smem_thr_copy_g2s.partition_S(gdQaccum_stage_g2s) + tdQsdQ = smem_thr_copy_g2s.partition_D(sdQaccum) + + cute.copy(smem_thr_copy_g2s, tdQgdQ[None, None, 0], tdQsdQ[None, None, 0]) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads) + + # S -> R + tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage, None, None] + tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0] + tdQrdQ_r2s_cpy = cute.make_tensor( + tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape) + ) + + cute.copy(s2r_thr_copy_dQaccum, tdQsdQ_s2r_p, tdQrdQ_r2s_cpy) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads) + + # R->S + tdQrdQ_r2s_cpy = cute.make_tensor( + cute.recast_ptr(tdQrdQ_r2s_cpy.iterator), + tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape, + ) + dQ_vec = tdQrdQ_r2s_cpy.load() * scale + tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].store(dQ_vec.to(self.dtype)) + + cute.copy( + tiled_smem_store_r2s, + tdQrdQ_r2s[None, None, None, None, 0], + tdQsdQ_r2s[None, None, None, None, 0], + ) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads) + + # S-> G + gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) + tdQsdQ, tdQgdQ = cpasync.tma_partition( + tma_atom_dQ, + 0, + cute.make_layout(1), + cute.group_modes(sdQ, 0, 2), + cute.group_modes(gdQ, 0, 2), + ) + + cute.copy(tma_atom_dQ, tdQsdQ[None, 0], tdQgdQ[None, m_block]) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 21f209ed97f..985391a7898 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -3,14 +3,23 @@ # from Cutlass C++ to Cute-DSL. import math import operator -from typing import Type, Optional +from typing import Callable, Type, Optional import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute +from cutlass import Float32 from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.tile_scheduler import ( + ParamsBase, + SingleTileScheduler, + SingleTileVarlenScheduler, + TileSchedulerArguments, +) class FlashAttentionBackwardPreprocess: @@ -73,33 +82,25 @@ def _setup_attributes(self): # Thread layouts for copies # We want kBlockKGmem to be a power of 2 so that when we do the summing, # it's just between threads in the same warp - gmem_k_block_size = 128 if self.head_dim_padded % 128 == 0 else (64 if self.head_dim_padded % 64 == 0 else (32 if self.head_dim_padded % 32 == 0 else 16)) - universal_copy_bits = 128 - async_copy_elems = universal_copy_bits // self.dtype.width - # atom_universal_copy: universal copy atom for O & dO load - atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, - ) - # tOdO_layout: thread layout for O & dO load - self.gmem_threads_per_row = gmem_k_block_size // async_copy_elems - assert self.num_threads % self.gmem_threads_per_row == 0 - tOdO_layout = cute.make_ordered_layout( - (self.num_threads // self.gmem_threads_per_row, self.gmem_threads_per_row), order=(1, 0), + gmem_k_block_size = ( + 128 + if self.head_dim_padded % 128 == 0 + else ( + 64 + if self.head_dim_padded % 64 == 0 + else (32 if self.head_dim_padded % 32 == 0 else 16) + ) ) - # Value layouts for copies - vOdO_layout = cute.make_layout((1, async_copy_elems)) - self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tOdO_layout, vOdO_layout) - self.gmem_tiled_copy_dO = cute.make_tiled_copy_tv(atom_universal_copy, tOdO_layout, vOdO_layout) - - async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width - atom_universal_copy_accum = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=universal_copy_bits, + self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d( + self.dtype, gmem_k_block_size, self.num_threads ) - assert (self.m_block_size * self.head_dim_padded // async_copy_elems_accum) % self.num_threads == 0 - self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - atom_universal_copy_accum, - cute.make_layout(self.num_threads), - cute.make_layout(async_copy_elems_accum) + universal_copy_bits = 128 + num_copy_elems_dQaccum = universal_copy_bits // Float32.width + assert ( + self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum + ) % self.num_threads == 0 + self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_copy_elems_dQaccum ) @cute.jit @@ -111,33 +112,67 @@ def __call__( mLSE: Optional[cute.Tensor], mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], stream: cuda.CUstream, ): # Get the data type and check if it is fp16 or bf16 if cutlass.const_expr(not (mO.element_type == mdO.element_type)): raise TypeError("All tensors must have the same data type") - if cutlass.const_expr(not mO.element_type in [cutlass.Float16, cutlass.BFloat16]): + if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(not mdPsum.element_type in [cutlass.Float32]): + if cutlass.const_expr(mdPsum.element_type not in [Float32]): raise TypeError("dPsum tensor must be Float32") if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]): + if cutlass.const_expr(mdQaccum.element_type not in [Float32]): raise TypeError("dQaccum tensor must be Float32") if cutlass.const_expr(mLSE is not None): assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" - if cutlass.const_expr(not mLSE.element_type in [cutlass.Float32]): + if cutlass.const_expr(mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") - if cutlass.const_expr(not mLSElog2.element_type in [cutlass.Float32]): + if cutlass.const_expr(mLSElog2.element_type not in [Float32]): raise TypeError("LSElog2 tensor must be Float32") + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mO, mdO, mdQaccum = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mO, mdO, mdQaccum) + ] + self._setup_attributes() - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mO.shape[1], self.m_block_size), - cute.size(mO.shape[2]), - cute.size(mO.shape[0]), + if cutlass.const_expr(mCuSeqlensQ is not None): + TileScheduler = SingleTileVarlenScheduler + num_head = mO.shape[1] + num_batch = mCuSeqlensQ.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_head = mO.shape[2] + num_batch = mO.shape[0] + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mO.shape[1], self.m_block_size), + num_head=num_head, + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=0, + headdim_v=mO.shape[2], + total_q=mO.shape[0], + tile_shape_mn=(self.m_block_size, 1), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + self.kernel( mO, mdO, @@ -145,9 +180,12 @@ def __call__( mLSE, mLSElog2, mdQaccum, + mCuSeqlensQ, + mSeqUsedQ, self.gmem_tiled_copy_O, - self.gmem_tiled_copy_dO, self.gmem_tiled_copy_dQaccum, + tile_sched_params, + TileScheduler, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -163,99 +201,163 @@ def kernel( mLSE: Optional[cute.Tensor], mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, - gmem_tiled_copy_dO: cute.TiledCopy, gmem_tiled_copy_dQaccum: cute.TiledCopy, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - m_block, num_head, batch_size = cute.arch.block_idx() - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - blkOdO_shape = (self.m_block_size, self.head_dim_padded) - # (m_block_size, head_dim) - gO = cute.local_tile(mO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) - gdO = cute.local_tile(mdO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + m_block, num_head, batch_size, _ = work_tile.tile_idx - gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) - gmem_thr_copy_dO = gmem_tiled_copy_dO.get_slice(tidx) - # (CPY_Atom, CPY_M, CPY_K) - tOgO = gmem_thr_copy_O.partition_S(gO) - tOgdO = gmem_thr_copy_dO.partition_S(gdO) + if work_tile.is_valid_tile: + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + seqlen = SeqlenInfoQK.create( + batch_size, + mO.shape[1], + 0, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=None, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=None, + ) - # /////////////////////////////////////////////////////////////////////////////// - # Predicate: Mark indices that need to copy when problem_shape isn't a multiple - # of tile_shape - # /////////////////////////////////////////////////////////////////////////////// - # Construct identity layout for KV - cOdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) - tOcO = gmem_thr_copy_O.partition_S(cOdO) - t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cOdO) - tOpO = utils.predicate_k(tOcO, limit=mO.shape[3]) - tOcdO = gmem_thr_copy_dO.partition_S(cOdO) - t0OcdO = gmem_thr_copy_dO.get_slice(0).partition_S(cOdO) - tOpdO = utils.predicate_k(tOcdO, limit=mdO.shape[3]) - - seqlen_q = mO.shape[1] - seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[batch_size, None, num_head, None] + mdO_cur = mdO[batch_size, None, num_head, None] + mdPsum_cur = mdPsum[batch_size, num_head, None] + headdim_v = mO.shape[3] + else: + mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, num_head, None]) + mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, num_head, None]) - if cutlass.const_expr(mLSE is not None): - gLSE = cute.local_tile(mLSE[batch_size, num_head, None], (self.m_block_size,), (m_block,)) - lse = cutlass.Float32.inf - if tidx < seqlen_q - m_block * self.m_block_size: - lse = gLSE[tidx] - - tOrO = cute.make_fragment_like(tOgO) - tOrdO = cute.make_fragment_like(tOgdO) - assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) - assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) - assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) - for m in range(cute.size(tOrO.shape[1])): - # Instead of using tOcO, we using t0OcO and subtract the offset from the limit - # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. - if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: - cute.copy( - gmem_thr_copy_O, - tOgO[None, m, None], - tOrO[None, m, None], - pred=tOpO[None, m, None] if self.check_hdim_oob else None, - ) - cute.copy( - gmem_thr_copy_dO, - tOgdO[None, m, None], - tOrdO[None, m, None], - pred=tOpdO[None, m, None] if self.check_hdim_oob else None, - ) - # Sum across the "k" dimension - dpsum = ( - tOrO.load().to(cutlass.Float32) * tOrdO.load().to(cutlass.Float32) - ).reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)) - dpsum = utils.warp_reduce(dpsum, operator.add, width=self.gmem_threads_per_row) - dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), cutlass.Float32) - dP_sum.store(dpsum) - - # Write dPsum from rmem -> gmem - gdPsum = cute.local_tile(mdPsum[batch_size, num_head, None], (self.m_block_size,), (m_block,)) - # Only the thread corresponding to column 0 writes out the lse to gmem - if tOcO[0, 0, 0][1] == 0: - for m in cutlass.range_constexpr(cute.size(dP_sum)): - row = tOcO[0, m, 0][0] - gdPsum[row] = dP_sum[m] if row < mO.shape[1] - m_block * self.m_block_size else 0.0 - - # Clear dQaccum - if cutlass.const_expr(mdQaccum is not None): - blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile(mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,)) - gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) - tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) - zero = cute.make_fragment_like(tQgQaccum) - zero.fill(0.0) - cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum) + padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[num_head, None]) + headdim_v = mO.shape[2] - if cutlass.const_expr(mLSE is not None): - gLSElog2 = cute.local_tile(mLSElog2[batch_size, num_head, None], (self.m_block_size,), (m_block,)) - LOG2_E = math.log2(math.e) - if tidx < seqlen_q_rounded - m_block * self.m_block_size: - gLSElog2[tidx] = lse * LOG2_E if lse != -cutlass.Float32.inf else 0.0 + blkOdO_shape = (self.m_block_size, self.head_dim_padded) + # (m_block_size, head_dim) + gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0)) + gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0)) + + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K) + tOgO = gmem_thr_copy_O.partition_S(gO) + tOgdO = gmem_thr_copy_O.partition_S(gdO) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=headdim_v) + tOpdO = utils.predicate_k(tOcO, limit=headdim_v) + + seqlen_q = seqlen.seqlen_q + seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + + if cutlass.const_expr(mLSE is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[batch_size, num_head, None] + else: + mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[num_head, None]) + + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) + lse = Float32.inf + if tidx < seqlen_q - m_block * self.m_block_size: + lse = gLSE[tidx] + + tOrO = cute.make_fragment_like(tOgO) + tOrdO = cute.make_fragment_like(tOgdO) + assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) + assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) + assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) + for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): + # Instead of using tOcO, we using t0OcO and subtract the offset from the limit + # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. + if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_thr_copy_O, + tOgO[None, m, None], + tOrO[None, m, None], + pred=tOpO[None, m, None] + if cutlass.const_expr(self.check_hdim_oob) + else None, + ) + cute.copy( + gmem_thr_copy_O, + tOgdO[None, m, None], + tOrdO[None, m, None], + pred=tOpdO[None, m, None] + if cutlass.const_expr(self.check_hdim_oob) + else None, + ) + # Sum across the "k" dimension + dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) + ) + threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] + assert cute.arch.WARP_SIZE % threads_per_row == 0 + dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row) + dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32) + dP_sum.store(dpsum) + + # Write dPsum from rmem -> gmem + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,)) + # Only the thread corresponding to column 0 writes out the dPsum to gmem + if tOcO[0, 0, 0][1] == 0: + for m in cutlass.range(cute.size(dP_sum), unroll_full=True): + row = tOcO[0, m, 0][0] + gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0 + + # Clear dQaccum + if cutlass.const_expr(mdQaccum is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mdQaccum_cur = mdQaccum[batch_size, num_head, None] + else: + padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + mdQaccum_cur = cute.domain_offset( + (padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None] + ) + + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.head_dim_padded keeps alignment + # since statically divisible by 4 + + mdQaccum_cur_ptr = cute.make_ptr( + dtype=mdQaccum_cur.element_type, + value=mdQaccum_cur.iterator.toint(), + mem_space=mdQaccum_cur.iterator.memspace, + assumed_align=mdQaccum.iterator.alignment, + ) + mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) + + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + zero = cute.make_fragment_like(tdQgdQaccum) + zero.fill(0.0) + cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) + + if cutlass.const_expr(mLSE is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mLSElog2_cur = mLSElog2[batch_size, num_head, None] + else: + padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[num_head, None]) + + gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) + LOG2_E = math.log2(math.e) + if tidx < seqlen_q_rounded - m_block * self.m_block_size: + gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py new file mode 100644 index 00000000000..00c8cbf66d7 --- /dev/null +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -0,0 +1,2538 @@ +# Copyright (c) 2025, Ted Zadouri, Markus Hoehnerbach, Jay Shah, Tri Dao. +import math +from typing import Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.utils import LayoutEnum +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils_basic +from cutlass.pipeline import PipelineAsync, PipelineConsumer + +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute import pipeline +from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTBwdScheduler, # noqa + ParamsBase, +) + +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 + + +class FlashAttentionBackwardSm100: + arch = 100 + + def __init__( + self, + head_dim: int, + head_dim_v: Optional[int] = None, + is_causal: bool = False, + is_local: bool = False, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + tile_m: int = 128, + tile_n: int = 128, + is_persistent: bool = False, + deterministic: bool = False, + cluster_size: int = 1, + ): + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + assert self.tile_hdim == self.tile_hdimv, ( + "tile_hdim and tile_hdimv must be the same for now" + ) + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv + + self.tile_m = tile_m + self.tile_n = tile_n + + # CTA tiler + self.cta_tiler = (tile_m, tile_n, self.tile_hdim) + # S = K @ Q.T + self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim) + # dP = V @ dO.T + self.mma_tiler_vdo = (tile_n, tile_m, self.tile_hdimv) + # dV = P.T @ dO + self.mma_tiler_pdo = (tile_n, self.tile_hdimv, tile_m) + # dK = dS.T @ Q (N, M) (M, D) + self.mma_tiler_dsq = (tile_n, self.tile_hdimv, tile_m) + # dQ = dS @ K + self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n) + + self.acc_dtype = Float32 + + assert cluster_size in (1, 2), "Only cluster_size=1 or 2 is supported" + self.cluster_shape_mn = (cluster_size, 1) + self.is_persistent = is_persistent + self.is_causal = is_causal + self.is_local = is_local + self.qhead_per_kvhead = qhead_per_kvhead + self.pack_gqa = False + self.use_tma_store = True + self.deterministic = deterministic + + # Speed optimizations, does not affect correctness + self.shuffle_LSE = False + self.shuffle_dPsum = False + self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal + + self.reduce_warp_ids = (0, 1, 2, 3) + self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epi_warp_id = 14 + self.empty_warp_id = 15 + + # 16 warps -> 512 threads + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.reduce_warp_ids, + *self.compute_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epi_warp_id, + self.empty_warp_id, + ) + ) + + # NamedBarrier + self.compute_sync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.Compute), + num_threads=len(self.compute_warp_ids) * cute.arch.WARP_SIZE, + ) + # self.epilogue_sync_barrier = pipeline.NamedBarrier( + # barrier_id=2, + # num_threads=self.num_compute_warps * self.threads_per_warp, + # ) + self.reduce_sync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE, + ) + + # TMEM setup + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + # self.tmem_dK_offset = 0 + # self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim + # self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv + # self.tmem_dP_offset = self.tmem_dQ_offset # overlap with dQ + # self.tmem_S_offset = self.tmem_dQ_offset + max(self.tile_m, self.tile_hdim) + # self.tmem_P_offset = self.tmem_S_offset # overlap with S + # self.tmem_total = self.tmem_S_offset + self.tile_n + # assert self.tmem_total <= self.tmem_alloc_cols + + self.tmem_S_offset = 0 + self.tmem_P_offset = 0 # overlap with S + self.tmem_dV_offset = self.tmem_S_offset + self.tile_n + self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv + self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP + self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m + self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP + + if (not is_causal and not is_local) or deterministic: + self.num_regs_reduce = 152 + self.num_regs_compute = 136 + else: + self.num_regs_reduce = 136 + self.num_regs_compute = 144 + self.num_regs_other = 96 - 8 + self.num_regs_empty = 24 + assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.Q_stage = 2 + self.dO_stage = 1 + # LSE_stage = Q_stage and dPsum_stage = dO_stage + # self.sdKVaccum_stage = 2 + # number of tma reduce adds per dQacc mma + self.dQ_reduce_ncol = 32 + self.sdQaccum_stage = 64 // self.dQ_reduce_ncol + assert self.tile_hdim % self.dQ_reduce_ncol == 0 + self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol + self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1 + # number of tma reduce adds for dKacc and dVacc epilogue + self.dK_reduce_ncol = 32 + + def _get_tiled_mma(self): + cta_group = tcgen05.CtaGroup.ONE + # S = K @ Q.T + tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma( + self.q_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.mma_tiler_kq[:2], + ) + # dP = V @ dO.T + tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.mma_tiler_vdo[:2], + ) + # dV += P @ dO --> (K, MN) major + tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, # P_major_mode + tcgen05.OperandMajorMode.MN, # dO_major_mode + self.acc_dtype, + cta_group, + self.mma_tiler_pdo[:2], + a_source=tcgen05.OperandSource.TMEM, + ) + # dK += dS.T @ Q + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dK_a_src = tcgen05.OperandSource.SMEM + else: + mma_dK_a_src = tcgen05.OperandSource.TMEM + tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Q_major_mode + self.acc_dtype, + cta_group, + self.mma_tiler_dsq[:2], + a_source=mma_dK_a_src, + ) + # dQ = dS @ K + tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma( + self.k_dtype, + tcgen05.OperandMajorMode.MN, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Kt_major_mode + self.acc_dtype, + cta_group, + self.mma_tiler_dsk[:2], + ) + return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ + + def _setup_smem_layout(self): + # S = K @ Q.T + sK_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_S, + self.mma_tiler_kq, + self.k_dtype, + 1, + ) + self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0)) + self.sQ_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_S, + self.mma_tiler_kq, + self.q_dtype, + self.Q_stage, + ) + # dP = V @ dO.T + sV_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dP, + self.mma_tiler_vdo, + self.v_dtype, + 1, + ) + self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0)) + self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dP, + self.mma_tiler_vdo, + self.do_dtype, + self.dO_stage, + ) + # dV += P @ dO + tP_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dV, + self.mma_tiler_pdo, + self.do_dtype, + 1, + ) + self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0)) + self.sdO_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dV, + self.mma_tiler_pdo, + self.do_dtype, + self.dO_stage, + ) + # dK += dS.T @ Q + sdSt_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dK, + self.mma_tiler_dsq, + self.ds_dtype, + 1, + ) + self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0)) + tdS_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dK, + self.mma_tiler_dsq, + self.ds_dtype, + 1, + ) + self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0)) + self.sQt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dK, + self.mma_tiler_dsq, + self.q_dtype, + self.Q_stage, + ) + # dQ = dS @ K + sdS_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dQ, + self.mma_tiler_dsk, + self.ds_dtype, + 1, + ) + self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0)) + sKt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dQ, + self.mma_tiler_dsk, + self.k_dtype, + 1, + ) + self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0)) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage) + ) + self.sLSE_layout = cute.make_layout( + shape=(self.tile_m, self.Q_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + self.sdPsum_layout = cute.make_layout( + shape=(self.tile_m, self.dO_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + self.sdKV_epi_tile = ( + self.tile_n, + min(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32 + ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] + # headdim_64 gets 1 stage + self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1]) + self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages + # TODO: dK and dV could have different shapes + if const_expr(self.qhead_per_kvhead == 1): + self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi( + self.dk_dtype, + LayoutEnum.ROW_MAJOR, + self.sdKV_epi_tile, + 2, # num compute wgs + ) + else: + self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, + ): + assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( + "Variable sequence length is not supported yet in FlashAttentionBackwardSm100" + ) + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.do_dtype = mdO.element_type + self.lse_dtype = mLSE.element_type + self.dpsum_dtype = mdPsum.element_type + self.dqaccum_dtype = mdQaccum.element_type + self.dk_dtype = mdK.element_type + self.dv_dtype = mdV.element_type + self.ds_dtype = self.q_dtype + + if const_expr(self.qhead_per_kvhead > 1): + assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" + assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + ( + mdQaccum, + mdK, + mdV, + ) = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in ( + mdQaccum, + mdK, + mdV, + ) + ] + + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + mQ, mK, mV, mdO = [utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO)] + LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) + mLSE, mdPsum, mdQaccum = [ + utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) + ] + if const_expr(self.qhead_per_kvhead == 1): + layout_dKV_transpose = layout_transpose + else: + layout_dKV_transpose = LSE_dPsum_dQaccum_transpose + mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)] + dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, b) + mdO = utils.select(mdO, mode=dO_transpose) + + semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) + if const_expr(self.deterministic): + assert mdQ_semaphore is not None + mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose) + + if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore, mdV_semaphore = [ + utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore) + ] + else: + mdK_semaphore = None + mdV_semaphore = None + + self._setup_attributes() + ( + self.tiled_mma_S, + self.tiled_mma_dP, + self.tiled_mma_dK, + self.tiled_mma_dV, + self.tiled_mma_dQ, + ) = self._get_tiled_mma() + self._setup_smem_layout() + + cta_group = tcgen05.CtaGroup.ONE + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (self.tiled_mma_S.thr_id.shape,), + ) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_q_do_mcast = self.num_mcast_ctas_b > 1 + + if const_expr(self.qhead_per_kvhead == 1): + self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) + self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) + dK_major_mode = self.mdK_layout_enum.mma_major_mode() + dV_major_mode = self.mdV_layout_enum.mma_major_mode() + if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdK is wrong") + if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdV is wrong") + + if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1): + tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() + tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( + tma_copy_op_dKV, + mdK, + cute.select(self.sdKV_layout, mode=[0, 1]), + self.sdKV_epi_tile, + 1, # no mcast + ) + tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom( + tma_copy_op_dKV, + mdV, + cute.select(self.sdKV_layout, mode=[0, 1]), + self.sdKV_epi_tile, + 1, # no mcast + ) + else: + mdV_tma_tensor = mdV + mdK_tma_tensor = mdK + tma_atom_dV = None + tma_atom_dK = None + + if const_expr(self.qhead_per_kvhead == 1): + thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads + val_layout_r2s_dKV = cute.make_ordered_layout( + (1, 128 // self.dk_dtype.width), order=(1, 0) + ) # 4 or 8 vals for 16 byte store + copy_atom_r2s_dKV = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dk_dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dKV = cute.make_tiled_copy_tv( + copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV + ) + else: + tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d( + Float32, 128, num_copy_elems=128 // Float32.width + ) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group) + + # S.T = K @ Q.T + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mK, + cute.select(self.sK_layout, mode=[0, 1, 2]), + self.mma_tiler_kq, + self.tiled_mma_S, + self.cluster_layout_vmnk.shape, + ) + Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma_S.thr_id + ) + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( + # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, + Q_tma_op, + mQ, + cute.select(self.sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_kq, + self.tiled_mma_S, + self.cluster_layout_vmnk.shape, + ) + # dP.T = V @ dO.T + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mV, + cute.select(self.sV_layout, mode=[0, 1, 2]), + self.mma_tiler_vdo, + self.tiled_mma_dP, + self.cluster_layout_vmnk.shape, + ) + dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma_dV.thr_id + ) + tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( + # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, + dO_tma_op, + mdO, + cute.select(self.sdO_layout, mode=[0, 1, 2]), + self.mma_tiler_pdo, + self.tiled_mma_dV, + self.cluster_layout_vmnk.shape, + ) + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ("dO", mdO, self.sdO_layout), + ] + } + self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 + self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8 + + # TileScheduler = SingleTileScheduler + if const_expr(self.deterministic): + TileScheduler = SingleTileLPTBwdScheduler + else: + TileScheduler = SingleTileScheduler + # reads n_blocks right-to-left + self.spt = (self.is_causal or self.is_local) and self.deterministic + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), + cute.size(mQ.shape[2]), # num_heads = num_query_heads + cute.size(mK.shape[3]), + 1, # num_splits + cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]), + tile_shape_mn=self.cta_tiler[:2], + cluster_shape_mn=self.cluster_shape_mnk[:2], + mCuSeqlensQ=None, + mSeqUsedQ=None, + qhead_per_kvhead_packgqa=1, + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, + lpt=self.spt, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + # cute.printf("grid_dim = {}", grid_dim) + + # Compute allocation sizes for shared buffers that are reused + # sQ is reused for sdK, sdO is reused for sdV + sQ_alloc_bytes = max( + cute.size_in_bytes(self.q_dtype, self.sQ_layout), + cute.size_in_bytes(self.dk_dtype, self.sdKV_layout), + ) + sdO_alloc_bytes = max( + cute.size_in_bytes(self.dv_dtype, self.sdKV_layout), + cute.size_in_bytes(self.do_dtype, self.sdO_layout), + ) + # Sanity check that layouts fit in allocation + sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdKV_layout) + sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdKV_layout) + assert sdV_bytes <= sdO_alloc_bytes, "sdV doesn't fit in sdO storage allocation" + assert sdK_bytes <= sQ_alloc_bytes, "sdK doesn't fit in sQ storage allocation" + + @cute.struct + class SharedStorage: + Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] + dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] + dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] + dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + tmem_holding_buf: Int32 + tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + + # Smem tensors + + # sQ is reused for sdK which in the non-MHA case needs float32 + sQ: cute.struct.Align[ + cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], + self.buffer_align_bytes, + ] + # sdO is reused for sdV which in the non-MHA case needs float32 + sdO: cute.struct.Align[ + cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes], + self.buffer_align_bytes, + ] + sdS: cute.struct.Align[ + cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], + 128, + ] + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], + 128, + ] + sdPsum: cute.struct.Align[ + cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], + 128, + ] + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + LOG2_E = math.log2(math.e) + softmax_scale_log2 = softmax_scale * LOG2_E + + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + mLSE, + mdPsum, + tma_tensor_dO, + mdV, + mdK, + mdQaccum, + mdV_tma_tensor, + mdK_tma_tensor, + mdQ_semaphore, + mdK_semaphore, + mdV_semaphore, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + tma_atom_dV, + tma_atom_dK, + self.sQ_layout, + self.sQt_layout, + self.sK_layout, + self.sV_layout, + self.sLSE_layout, + self.sdPsum_layout, + self.sdO_layout, + self.sdOt_layout, + self.sdSt_layout, + self.sdS_layout, + self.sKt_layout, + self.sdQaccum_layout, + self.sdKV_layout, + self.tP_layout, + self.tdS_layout, + self.tiled_mma_S, + self.tiled_mma_dP, + self.tiled_mma_dV, + self.tiled_mma_dK, + self.tiled_mma_dQ, + tiled_copy_r2s_dKV, + softmax_scale, + softmax_scale_log2, + window_size_left, + window_size_right, + tile_sched_params, + ).launch( + grid=grid_dim, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdO: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + mdQaccum: cute.Tensor, + mdV_tma_tensor: Optional[cute.Tensor], + mdK_tma_tensor: Optional[cute.Tensor], + mdQ_semaphore: Optional[cute.Tensor], + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + sQ_layout: cute.ComposedLayout, + sQt_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sLSE_layout: cute.Layout, + sdPsum_layout: cute.Layout, + sdO_layout: cute.ComposedLayout, + sdOt_layout: cute.ComposedLayout, + sdSt_layout: cute.ComposedLayout, + sdS_layout: cute.ComposedLayout, + sKt_layout: cute.ComposedLayout, + sdQaccum_layout: cute.Layout, + sdKV_layout: cute.ComposedLayout | cute.Layout, + tP_layout: cute.ComposedLayout, + tdS_layout: cute.ComposedLayout, + tiled_mma_S: cute.TiledMma, + tiled_mma_dP: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + tiled_copy_r2s_dKV: cute.TiledCopy, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + tile_sched_params: ParamsBase, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == self.load_warp_id: + with cute.arch.elect_one(): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_dO) + if const_expr(tma_atom_dV is not None): + cpasync.prefetch_descriptor(tma_atom_dV) + if const_expr(tma_atom_dK is not None): + cpasync.prefetch_descriptor(tma_atom_dK) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_S.thr_id.shape,), + ) + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() + dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr() + dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr() + + if warp_idx == 1: + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) + ) + if const_expr(self.cluster_reduce_dQ): + if warp_idx == 4: + for i in range(self.dQaccum_reduce_stage // 2): + cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1) + cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1) + + # UMMA producers and AsyncThread consumers + pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + # Only 1 thread per warp will signal + pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) + ) + pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=1, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.S_mbar_ptr.data_ptr(), + ) + pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=1, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.dP_mbar_ptr.data_ptr(), + ) + pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=2, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.dKV_mbar_ptr.data_ptr(), + ) + pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, + len(self.reduce_warp_ids), + ) # Compute + pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=1, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, + barrier_storage=storage.dQ_mbar_ptr.data_ptr(), + ) + + # AsyncThread producers and UMMA consumers + # Only 1 thread per warp will signal + pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) + ) # Compute + pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) # MMA + pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create( + num_stages=1, + producer_group=pipeline_PdS_producer_group, + consumer_group=pipeline_PdS_consumer_group, + barrier_storage=storage.dS_mbar_ptr.data_ptr(), + ) + + # TMA producer and UMMA consumers + pipeline_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + # The arrive count is the number of mcast size + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b + ) + pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup( + # cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.num_mcast_ctas_b + cutlass.pipeline.Agent.Thread, + len(self.compute_warp_ids) * 1, + ) + pipeline_LSE = cutlass.pipeline.PipelineTmaAsync.create( + barrier_storage=storage.LSE_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group_compute, + tx_count=self.tma_copy_bytes["LSE"], + # cta_layout_vmnk=cluster_layout_vmnk, + # init_wait=False, + ) + pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create( + barrier_storage=storage.dPsum_mbar_ptr.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group_compute, + tx_count=self.tma_copy_bytes["dPsum"], + # cta_layout_vmnk=cluster_layout_vmnk, + # init_wait=False, + ) + pipeline_Q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.Q_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["Q"], + cta_layout_vmnk=cluster_layout_vmnk, + init_wait=False, + ) + pipeline_dO = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.dO_mbar_ptr.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["dO"], + cta_layout_vmnk=cluster_layout_vmnk, + init_wait=True, + ) + + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype) + sQt = cute.make_tensor( + cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer + ) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) + sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer) + sdO = storage.sdO.get_tensor( + sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype + ) + sdOt = cute.make_tensor( + cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype), sdOt_layout.outer + ) + sLSE = storage.sLSE.get_tensor(sLSE_layout) + sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) + if const_expr(self.qhead_per_kvhead == 1): + sdV = storage.sdO.get_tensor( + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype + ) + sdK = storage.sQ.get_tensor( + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype + ) + else: + sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype) + sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype) + + # Buffer sizing is guaranteed by max(...) in SharedStorage declarations + # for both sQ (reused as sdK) and sdO (reused as sdV) + + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) + + # TMEM + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) + # S + thr_mma_S = tiled_mma_S.get_slice(0) + Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) + tStS = thr_mma_S.make_fragment_C(Sacc_shape) + # (MMA, MMA_M, MMA_N) + tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout) + # dP + thr_mma_dP = tiled_mma_dP.get_slice(0) + dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2]) + tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape) + tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout) + # dV + thr_mma_dV = tiled_mma_dV.get_slice(0) + dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2]) + tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape) + tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout) + tP = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer + ) + # dK + thr_mma_dK = tiled_mma_dK.get_slice(0) + dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2]) + tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape) + tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout) + tdS = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer + ) + # dQ + thr_mma_dQ = tiled_mma_dQ.get_slice(0) + dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) + tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout) + + block_info = BlockInfo( + self.tile_m, + # self.tile_n, + self.tile_n * self.cluster_shape_mnk[0], # careful, this case is not very well-tested + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=None, + mCuSeqlensK=None, + mSeqUsedQ=None, + mSeqUsedK=None, + ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + swap_AB=True, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + + # EMPTY + # (15) + if warp_idx == self.empty_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # EPI + # (14) + if warp_idx == self.epi_warp_id: + # currently no-op, could use for tma store/reduce + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # LOAD + # (13) + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.load( + thr_mma_S, + thr_mma_dP, + thr_mma_dV, + mQ, + mK, + mV, + mLSE, + mdPsum, + mdO, + sQ, + sK, + sV, + sLSE, + sdPsum, + sdO, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + pipeline_Q, + pipeline_dO, + pipeline_LSE, + pipeline_dPsum, + cluster_layout_vmnk, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + should_load_Q=True, + should_load_dO=True, + ) + + # MMA + # (12) + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_S, + tiled_mma_dP, + tiled_mma_dV, + tiled_mma_dK, + tiled_mma_dQ, + sQ, + sQt, + sK, + sV, + sdO, + sdOt, + sdSt, + sdS, + sKt, + tP, + tdS, + tStS, + tdPtdP, + tdVtdV, + tdKtdK, + tdQtdQ, + pipeline_Q.make_consumer(), + pipeline_dO, + pipeline_S_P, + pipeline_dS, + pipeline_dKV, + pipeline_dP, + pipeline_dQ, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + cute.arch.relinquish_tmem_alloc_permit() + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf + ) + + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols, is_two_cta=False) + + # Compute + # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps + if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps + self.compute_loop( + thr_mma_S, + thr_mma_dP, + thr_mma_dV, + thr_mma_dK, + tStS, + sLSE, + sdPsum, + tdVtdV, + tdKtdK, + mdV, + mdK, + sdS, + tdPtdP, + pipeline_LSE, + pipeline_dPsum, + pipeline_S_P, + pipeline_dS, + pipeline_dKV, + pipeline_dP, + softmax_scale, + softmax_scale_log2, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + sdV, + sdK, + mdV_tma_tensor, + mdK_tma_tensor, + tma_atom_dV, + tma_atom_dK, + tiled_copy_r2s_dKV, + mdK_semaphore, + mdV_semaphore, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # Reduce + # (0, 1, 2, 3) - dQ + if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_reduce) + self.dQacc_reduce( + mdQaccum, + sdQaccum, + thr_mma_dQ, + tdQtdQ, + pipeline_dQ, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + mdQ_semaphore, + ) + + return + + @cute.jit + def load( + self, + thr_mma_S: cute.core.ThrMma, + thr_mma_dP: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdO: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + sdO: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + pipeline_Q: PipelineAsync, + pipeline_dO: PipelineAsync, + pipeline_LSE: PipelineAsync, + pipeline_dPsum: PipelineAsync, + cluster_layout_vmnk: cute.Layout, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + should_load_Q: bool = True, + should_load_dO: bool = True, + ): + producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage + ) + producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) + + # Compute multicast mask for Q & dO buffer full + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + q_do_mcast_mask = None + if const_expr(self.is_q_do_mcast): + q_do_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + head_idx_kv = head_idx // self.qhead_per_kvhead + mQ_cur = mQ[None, None, head_idx, batch_idx] + mK_cur = mK[None, None, head_idx_kv, batch_idx] + mV_cur = mV[None, None, head_idx_kv, batch_idx] + mdO_cur = mdO[None, None, head_idx, batch_idx] + mLSE_cur = mLSE[None, head_idx, batch_idx] + mPsum_cur = mdPsum[None, head_idx, batch_idx] + + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) + tSgK = thr_mma_S.partition_A(gK) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0)) + tdPgV = thr_mma_dP.partition_A(gV) + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) + tSgQ = thr_mma_S.partition_B(gQ) + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) + gdPsum = cute.local_tile(mPsum_cur, (self.tile_m,), (None,)) + gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) + tdPgdO = thr_mma_dV.partition_B(gdO) + + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True + ) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, + 0, + cute.make_layout(1), + tdPgV, + sV, + single_stage=True, + ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tSgQ, + dst_tensor=sQ, + mcast_mask=q_do_mcast_mask, + ) + load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) + load_dO, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dO, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tdPgdO, + dst_tensor=sdO, + mcast_mask=q_do_mcast_mask, + ) + load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) + copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32) + copy_stats = partial(cute.copy, copy_atom_stats) + # copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32) + # sLSE = cute.logical_divide(sLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # gLSE = cute.logical_divide(gLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # sdPsum = cute.logical_divide(sdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask) + + if const_expr(not self.is_local) or m_block_min < m_block_max: + # First iteration: load K together w Q & LSE, then V together w dO & dPsum + if const_expr(should_load_Q): + # K & Q + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(m_block_min, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + # LSE + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block_min], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + # V & dO + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) + load_dO(m_block_min, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block_min], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), + ) + producer_state_dO_dPsum.advance() + + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(should_load_Q): + # Q + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + # LSE + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + # dO + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dO_dPsum + ), + ) + producer_state_dO_dPsum.advance() + + if const_expr(should_load_Q): + pipeline_Q.producer_tail( + producer_state_Q_LSE.clone() + ) # will hang if we don't clone + pipeline_LSE.producer_tail(producer_state_Q_LSE) + if const_expr(should_load_dO): + pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) + pipeline_dPsum.producer_tail(producer_state_dO_dPsum) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma( + self, + tiled_mma_S: cute.TiledMma, + tiled_mma_dP: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + sQ: cute.Tensor, + sQt: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sdOt: cute.Tensor, + sdSt: cute.Tensor, + sdS: cute.Tensor, + sKt: cute.Tensor, + tP: cute.Tensor, + tdS: cute.Tensor, + tStS: cute.Tensor, + tdPtdP: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + tdQtdQ: cute.Tensor, + pipeline_Q_consumer: PipelineConsumer, + pipeline_dO: PipelineAsync, + pipeline_S_P: PipelineAsync, + pipeline_dS: PipelineAsync, + pipeline_dKV: PipelineAsync, + pipeline_dP: PipelineAsync, + pipeline_dQ: PipelineAsync, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + # [2025-10-21] For reasons I don't understand, putting these partitioning in the main + # kernel (before warp specialization) is a lot slower tha putting them here. + # Partition smem / tmem tensors + # S = K @ Q.T + tSrK = tiled_mma_S.make_fragment_A(sK) + tSrQ = tiled_mma_S.make_fragment_B(sQ) + # dP = V @ dO.T + tdPrV = tiled_mma_dP.make_fragment_A(sV) + tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt) + # dK = dS.T @ Q + if const_expr(self.use_smem_dS_for_mma_dK): + tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) + else: + tdKrdS = tiled_mma_dK.make_fragment_A(tdS) + tdKrQ = tiled_mma_dK.make_fragment_B(sQt) + # dQ = dS @ K + tdQrdS = tiled_mma_dQ.make_fragment_A(sdS) + tdQrK = tiled_mma_dQ.make_fragment_B(sKt) + # dV = P @ dO.T + tdVrdO = tiled_mma_dV.make_fragment_B(sdO) + tdVrP = tiled_mma_dV.make_fragment_A(tP) + + # mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True) + mma_qk_fn = partial( + gemm_ptx_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True + ) + # mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True) + mma_dov_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dP, + tdPtdP, + tdPrV, + tdPrdOt, + sA=sV, + sB=sdOt, + zero_init=True, + ) + # mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO) + mma_pdo_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dV, + tdVtdV, + tdVrP, + tdVrdO, + sA=None, + sB=sdO, + tA_addr=self.tmem_P_offset, + ) + mma_dsk_fn = partial(gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, zero_init=True) + # mma_dsk_fn = partial( + # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True + # ) + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) + else: + # Need to explicitly pass in tA_addr for correctness + mma_dsq_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dK, + tdKtdK, + tdKrdS, + tdKrQ, + sA=None, + sB=sQt, + tA_addr=self.tmem_dS_offset, + ) + + consumer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) + producer_phase_acc = Int32(1) # For S & P, dP, dQ + consumer_state_dS = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 1 + ) + # producer_state_dKV = cutlass.pipeline.make_pipeline_state( + # cutlass.pipeline.PipelineUserType.Producer, 2 + # ) + producer_phase_dKV = Int32(1) + cta_group = pipeline_S_P.cta_group + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + if const_expr(not self.is_local) or m_block_min < m_block_max: + accumulate_dK = False + # ----------------------------------------------------------- + ###### Prologue + # ----------------------------------------------------------- + # 1. S = Q0 @ K.T + # 2. dP = V @ dO.T + # 3. dV = P @ dO + # 1) S = Q0 @ K.T + handle_Q = pipeline_Q_consumer.wait_and_advance() + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_qk_fn(B_idx=handle_Q.index) + # Don't release Q yet + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # 2) dP = V @ dO.T + pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) + mma_dov_fn(B_idx=consumer_state_dO.index) + # Don't release dO yet + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + producer_phase_acc ^= 1 + # 3) dV = P.T @ dO + # wait for P to be ready, which uses the same tmem as S + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + # ----------------------------------------------------------- + ###### MAIN LOOP + # ----------------------------------------------------------- + # 1. S = K @ Q.T + # 2. dQ = dS @ K + # 3. dK = dS.T @ Q + # 4. dP = V @ dO.T + # 5. dV = P.T @ dO + + for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + # 1) S = K @ Q_i + handle_Q_next = pipeline_Q_consumer.wait_and_advance() + # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready + mma_qk_fn(B_idx=handle_Q_next.index) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # 2-3) + # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma + # Otherwise, reverse order + pipeline_dS.consumer_wait(consumer_state_dS) + + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + else: + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + + # dP uses the same tmem as dQ + # However, if dS is ready, then dP must have been ready, + # so we don't need this wait before mma_dsk_fn() + # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + + # 4) dP = V @ dO.T + pipeline_dO.consumer_wait(consumer_state_dO) + # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + producer_phase_acc ^= 1 + # 5) dV += P @ dO + # wait for P to be ready, which uses the same tmem as S + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + + handle_Q = handle_Q_next + + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # signal to the epilogue that dV is ready + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) + + # ----------------------------------------------------------- + ###### Remaining 2 + # ----------------------------------------------------------- + # 1) dK += dS.T @ Q + pipeline_dS.consumer_wait(consumer_state_dS) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + # signal to the epilogue that dK is ready + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + producer_phase_dKV ^= 1 + + # 2) dQ = dS @ K + # dS is done, so dP must have been ready, we don't need to wait + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier + handle_Q.release() + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + + producer_phase_acc ^= 1 + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + # Currently it hangs if we have this S_P.producer_tail, will need to understand why + # pipeline_S_P.producer_tail(producer_state_S_P) + # pipeline_dP.producer_tail(producer_state_dP) + # pipeline_dKV.producer_tail(producer_state_dKV) + # pipeline_dQ.producer_tail(producer_state_dQ) + + @cute.jit + def split_wg( + self, + t: cute.Tensor, + wg_idx: cutlass.Int32, + num_wg: cutlass.Constexpr[int], + ): + reduced_shape = cute.product_each(t.shape) + rank = len(reduced_shape) + if const_expr(reduced_shape[1] > 1): + assert rank >= 2, "Need rank >= 2 for t in split_wg" + t = cute.logical_divide(t, (reduced_shape[0], reduced_shape[1] // num_wg)) + coord = (None, (None, wg_idx)) + (None,) * (rank - 2) + else: + assert rank >= 3, "Need rank >= 3 for t in split_wg" + if const_expr(rank == 3): + t = cute.logical_divide( + t, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg) + ) + coord = ( + None, + None, + (None, wg_idx), + ) + (None,) * (rank - 3) + else: + t = cute.logical_divide( + t, + ( + reduced_shape[0], + reduced_shape[1], + reduced_shape[2], + reduced_shape[3] // num_wg, + ), + ) + coord = ( + None, + None, + None, + (None, wg_idx), + ) + (None,) * (rank - 4) + return t[coord] + + @cute.jit + def compute_loop( + self, + thr_mma_S: cute.core.ThrMma, + thr_mma_dP: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, + thr_mma_dK: cute.core.ThrMma, + tStS: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + sdS: cute.Tensor, + tdPtdP: cute.Tensor, + pipeline_LSE: PipelineAsync, + pipeline_dPsum: PipelineAsync, + pipeline_S_P: PipelineAsync, + pipeline_dS: PipelineAsync, + pipeline_dKV: PipelineAsync, + pipeline_dP: PipelineAsync, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + sdV: Optional[cute.Tensor], + sdK: Optional[cute.Tensor], + mdV_tma_tensor: Optional[cute.Tensor], + mdK_tma_tensor: Optional[cute.Tensor], + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + tiled_copy_r2s_dKV: Optional[cute.TiledCopy], + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], + ): + sLSE_2D = cute.make_tensor( + sLSE.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.Q_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + sdPsum_2D = cute.make_tensor( + sdPsum.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.dO_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + # if const_expr(self.SdP_swapAB): + if const_expr(True): + sLSE_2D = utils.transpose_view(sLSE_2D) + sdPsum_2D = utils.transpose_view(sdPsum_2D) + + # tix: [128...384] 8 warps + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + # tidx = cute.arch.thread_idx()[0] - (cute.arch.WARP_SIZE * self.compute_warp_ids[0]) + dp_idx = tidx % 128 + num_wg = len(self.compute_warp_ids) // 4 # 2 + # wg_idx: + # 0: [256...384] + # 1: [128...256] + + tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # 64 for tile_n = 128 + # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1) + # tP overlap with tS + tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + tStP = cute.make_tensor(tStS.iterator, tStP.layout) # Otherwise the tmem address is wrong + tScS = thr_mma_S.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) + tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + # tdS overlap with tdP + tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2])) + tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 + ) + + # tmem -> rmem + thr_copy_t2r = copy_utils.make_tmem_copy(tmem_load_atom, num_wg).get_slice(tidx) + tStS_t2r = thr_copy_t2r.partition_S(tStS) # (((32, 32), 1), 2, 1, 1) + tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP) + tScS_t2r = thr_copy_t2r.partition_D(tScS) # ((32, 1), 2, 1, 1) + t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS) # ((32, 1), 2, 1, 1) + # ((32, 1), 2, 1, 1, STAGE) + tSsLSE = thr_copy_t2r.partition_D(thr_mma_S.partition_C(sLSE_2D)) + tSsdPsum = thr_copy_t2r.partition_D(thr_mma_dP.partition_C(sdPsum_2D)) + # rmem -> tmem + thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx) + tScP_r2t = thr_copy_r2t.partition_S(tScP) + tStP_r2t = thr_copy_r2t.partition_D(tStP) + tdPcdS_r2t = thr_copy_r2t.partition_S(tdPcdS) + tdPtdS_r2t = thr_copy_r2t.partition_D(tdPtdS) + # rmem -> smem + # This part is a bit iffy, we might be making a lot of assumptions here + copy_atom_r2s = sm100_utils_basic.get_smem_store_op( + LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r + ) + thr_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx) + # We assume the swizzle (i.e. layout.inner) stays the same + sdS_layout = sm100_utils_basic.make_smem_layout_epi( + self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1 + ).outer # ((8,16), (64,2), (1, 1)) + sdS_layout = cute.slice_(sdS_layout, (None, None, 0)) # ((8,16), (64,2)) + # Need to group into 1 mode to be compatible w thr_copy_r2s + sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,)) + sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout) + tRS_sdS = thr_copy_r2s.partition_D(sdS_epi) + + consumer_state_S_P_dP = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 + cutlass.pipeline.PipelineUserType.Consumer, 1 + ) + # consumer_phase_S_P_dP = Int32(0) + producer_state_dS = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 + cutlass.pipeline.PipelineUserType.Producer, 1 + ) + consumer_state_dKV = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 2 + ) + consumer_state_LSE = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage + ) + # consumer_state_dPsum = cutlass.pipeline.make_pipeline_state( + consumer_state_dPsum = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + # TODO: condition mask_seqlen + mask_fn = partial( + mask.apply_mask_sm100_transposed, + tScS_t2r=tScS_t2r, + t0ScS_t2r=t0ScS_t2r, + n_block=n_block, + mask_seqlen=True, + mask_causal=self.is_causal, + mask_local=self.is_local, + ) + + # prefetch_LSE = not self.is_causal + prefetch_LSE = False + + # Mainloop + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + # Prefetch 1 stage of LSE + pipeline_LSE.consumer_wait(consumer_state_LSE) + tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32) + if const_expr(prefetch_LSE and not self.shuffle_LSE): + cute.autovec_copy(tSsLSE[None, 0, 0, 0, consumer_state_LSE.index], tSrLSE_s2r) + + pipeline_S_P.consumer_wait(consumer_state_S_P_dP) + # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP) + #### TMEM->RMEM (Load S from TMEM) + tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) + cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r) + + #### APPLY MASK + mask_fn(tSrS_t2r, m_block=m_block) + + num_stages = cute.size(tScS_t2r, mode=[1]) + + # --------------------------------------------- + #### P = exp(S - LSE) + # --------------------------------------------- + lane_idx = cute.arch.lane_idx() + tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32) # 64 + tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype) + for stage in cutlass.range_constexpr(num_stages): + tSrS_cur = tSrS_t2r[None, stage, 0, 0] + tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index] + if const_expr(not self.shuffle_LSE): + if const_expr(stage > 0 or not prefetch_LSE): + cute.autovec_copy(tSsLSE_cur, tSrLSE_s2r) + tSrLSE = tSrLSE_s2r + else: + tSrLSE = tSsLSE_cur[lane_idx] + for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2): + if const_expr(not self.shuffle_LSE): + lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1]) + else: + lse_pair = ( + utils.shuffle_sync(tSrLSE, offset=2 * v), + utils.shuffle_sync(tSrLSE, offset=2 * v + 1), + ) + tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = utils.fma_packed_f32x2( + ((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])), + (softmax_scale_log2, softmax_scale_log2), + (-lse_pair[0], -lse_pair[1]), + ) + tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True) + tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True) + utils.cvt_f16(tSrS_cur, tSrP_r2t[None, stage, 0, 0]) + if const_expr(stage == 0): + cute.arch.fence_view_async_tmem_load() + # Without this barrier, we could have 1 warp writing to P in tmem while + # another warp is still reading S from tmem. + self.compute_sync_barrier.arrive_and_wait() + cute.copy( + thr_copy_r2t, + tSrP_r2t_f32[None, stage, None, None], + tStP_r2t[None, stage, None, None], + ) + + cute.arch.fence_view_async_tmem_store() + self.compute_sync_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + pipeline_S_P.consumer_release(consumer_state_S_P_dP) + # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) + pipeline_LSE.consumer_release(consumer_state_LSE) + # consumer_state_S_P_dP.advance() + consumer_state_LSE.advance() + + # --------------------------------------------- + # dS.T = P.T * (dP.T - D) + # --------------------------------------------- + pipeline_dPsum.consumer_wait(consumer_state_dPsum) + + pipeline_dP.consumer_wait(consumer_state_S_P_dP) + # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) + consumer_state_S_P_dP.advance() + # consumer_phase_S_P_dP ^= 1 + + ##### dS.T = P.T * (dP.T - Psum) + for stage in cutlass.range_constexpr(num_stages): + tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) + cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r) + cute.arch.fence_view_async_tmem_load() + self.compute_sync_barrier.arrive_and_wait() + tdPrdP_cur = tdPrdP_t2r[None, 0, 0] + tSrS_cur = tSrS_t2r[None, stage, 0, 0] + tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index] + if const_expr(not self.shuffle_dPsum): + tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32) + cute.autovec_copy(tSsdPsum_cur, tSrdPsum) + else: + tSrdPsum = tSsdPsum_cur[lane_idx] + for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r, mode=[0]) // 2): + if const_expr(not self.shuffle_dPsum): + dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1]) + else: + dPsum_pair = ( + utils.shuffle_sync(tSrdPsum, offset=2 * v), + utils.shuffle_sync(tSrdPsum, offset=2 * v + 1), + ) + tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.sub_packed_f32x2( + (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair + ) + tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.mul_packed_f32x2( + (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]), + (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), + ) + tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype) + utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt) + if const_expr(stage == 0): + pipeline_dS.producer_acquire(producer_state_dS) + cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage]) + if const_expr(not self.use_smem_dS_for_mma_dK): + tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32) + cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0]) + + if const_expr(not self.use_smem_dS_for_mma_dK): + cute.arch.fence_view_async_tmem_store() + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + self.compute_sync_barrier.arrive_and_wait() + + # with cute.arch.elect_one(): + # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive + # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) + pipeline_dPsum.consumer_release(consumer_state_dPsum) + consumer_state_dPsum.advance() + with cute.arch.elect_one(): + pipeline_dS.producer_commit(producer_state_dS) + producer_state_dS.advance() + + # Epilogue + if const_expr(not self.is_local) or m_block_min < m_block_max: + if const_expr(not self.use_tma_store): + consumer_state_dKV = self.epilogue_dKV( + dp_idx, + warp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dV, + thr_mma_dK, + tdVtdV, + tdKtdK, + mdV, + mdK, + pipeline_dKV, + consumer_state_dKV, + softmax_scale, + ) + else: + thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx) + #### STORE dV + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dV, + tdVtdV, + mdV_tma_tensor, + sdV, + tma_atom_dV, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + None, # Don't scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdV_semaphore, + ) + #### STORE dK + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dK, + tdKtdK, + mdK_tma_tensor, + sdK, + tma_atom_dK, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None, + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdK_semaphore, + ) + if const_expr(self.qhead_per_kvhead == 1 and self.is_local): + if m_block_min >= m_block_max: + # if tidx == 0: + # cute.printf("m_block_min = {}, m_block_max = {}", m_block_min, m_block_max) + # like other epis, currently assumes hdim == hdimv + gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d( + self.dk_dtype, + self.tile_hdim, + 128, # num_threads + ) + gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx) + mdV_cur = mdV[None, None, head_idx, batch_idx] + mdK_cur = mdK[None, None, head_idx, batch_idx] + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK) + tdVgdV = gmem_thr_copy_zero_dKV.partition_D(gdV) + assert tdKgdK.shape[2] == 1 + assert tdVgdV.shape[2] == 1 + cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tdKVcdKV = gmem_thr_copy_zero_dKV.partition_D(cdKV) + zero = cute.make_fragment_like(tdKgdK[None, 0, 0]) + zero.fill(0.0) + if tidx < 128: + for i in cutlass.range_constexpr(tdKgdK.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdKgdK[None, i, 0]) + else: + for i in cutlass.range_constexpr(tdVgdV.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdVgdV[None, i, 0]) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def dQacc_reduce( + self, + mdQaccum: cute.Tensor, + sdQaccum: cute.Tensor, + thr_mma_dQ: cute.core.ThrMma, + tdQtdQ: cute.Tensor, + pipeline_dQ: PipelineAsync, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + mdQ_semaphore: Optional[cute.Tensor], + ): + num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_reduce_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids)) + is_tma_warp = warp_idx == 0 + # TMEM -> RMEM + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 + ) + thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) + tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ) + tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2])) + tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape + assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == self.dQaccum_reduce_stage, ( + "dQaccum reduce stage mismatch" + ) + + thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d( + self.dqaccum_dtype, num_reduce_threads, num_copy_elems=128 // self.dqaccum_dtype.width + ).get_slice(tidx) + tdQsdQ = thr_copy_dQaccum_r2s.partition_D(sdQaccum) + + read_flag = const_expr(not self.deterministic) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + dQ_consumer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 1 + ) + dQ_tma_store_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.sdQaccum_stage + ) + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) + # (M * K / STAGE, STAGE, _) + gdQaccum = cute.flat_divide( + gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,) + ) + + if const_expr(self.deterministic): + mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] + + delay_semaphore_release = self.is_causal + n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) + + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + pipeline_dQ.consumer_wait(dQ_consumer_state) + # TMEM -> RMEM + tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) + cute.copy(thr_copy_t2r, tdQtdQ_t2r, tdQrdQ_t2r) + cute.arch.fence_view_async_tmem_load() + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dQ.consumer_release(dQ_consumer_state) + dQ_consumer_state.advance() + + gdQaccum_cur = gdQaccum[None, None, m_block] + + for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 + smem_idx = dQ_tma_store_producer_state.index + tdQsdQ_r2s = tdQsdQ[None, None, smem_idx] + tdQrdQ_r2s = cute.make_tensor( + tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape + ) + cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + # semaphore acquire + if const_expr(self.deterministic and stage == 0): + if const_expr(self.spt): + if const_expr( + self.is_causal or block_info.window_size_right is not None + ): + n_idx_right = ( + (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q + ) + if const_expr(block_info.window_size_right is not None): + n_idx_right += block_info.window_size_right + n_block_max_for_m_block = min( + n_block_global_max, + cute.ceil_div(n_idx_right, self.tile_n), + ) + else: + n_block_max_for_m_block = n_block_global_max + lock_value = n_block_max_for_m_block - 1 - n_block + else: + lock_value = n_block + barrier.wait_eq( + mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value + ) + self.reduce_sync_barrier.arrive_and_wait() + # Copy from shared memory to global memory + if is_tma_warp: + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, smem_idx].iterator, + gdQaccum_cur[None, stage].iterator, + self.tma_copy_bytes["dQ"] // 1, + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + dQ_tma_store_producer_state.advance() + # Directly add to gmem, much slower + # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) + # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) + # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True): + # copy_utils.atomic_add_fp32x4( + # tdQrdQ_r2s[4 * i], + # tdQrdQ_r2s[4 * i + 1], + # tdQrdQ_r2s[4 * i + 2], + # tdQrdQ_r2s[4 * i + 3], + # utils.elem_pointer(tdQgdQ, 4 * i), + # ) + # semaphore release for prior m_block + if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): + if m_block > m_block_min: + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1 + ) + + # semaphore release + # NOTE: arrive_inc calls red_release which issues membar + if const_expr(self.deterministic and not delay_semaphore_release): + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1) + + if const_expr(not self.is_local) or m_block_min < m_block_max: + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + # final semaphore release + if const_expr(self.deterministic and delay_semaphore_release): + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1 + ) + + if const_expr( + self.deterministic and not self.spt and block_info.window_size_left is not None + ): + m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) + for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): + barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def epilogue_dKV( + self, + tidx: Int32, + warp_idx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + thr_mma_dV: cute.core.ThrMma, + thr_mma_dK: cute.core.ThrMma, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + pipeline_dKV: PipelineAsync, + consumer_state_dKV: cutlass.pipeline.PipelineState, + softmax_scale: Float32, + ): + wg_idx = ( + cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + ) // 128 + num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 + + assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA" + mdV_cur = mdV[None, None, head_idx, batch_idx] + mdK_cur = mdK[None, None, head_idx, batch_idx] + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32 + ) + + # dV + pipeline_dKV.consumer_wait(consumer_state_dKV) + + tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV) + thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx) + + tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV) + tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg) + + cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1])) + tdVcdV = thr_mma_dV.partition_C(cdV) + tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout) + + tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor) + tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg) + tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32) + + cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r) + cute.arch.fence_view_async_tmem_load() + + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dv_dtype, + num_bits_per_copy=universal_copy_bits, + ) + tiled_gmem_store_dV = cute.make_tiled_copy( + atom_universal_copy, + layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld_dV.tiler_mn, + ) + + tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype) + for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])): + dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() + tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) + + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + gdV_tile = gdV[None, None, n_block] + + tdVgdV = thr_mma_dV.partition_C(gdV_tile) + tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) + tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) + + cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g) + + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + + # dK + pipeline_dKV.consumer_wait(consumer_state_dKV) + + tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK) + thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx) + + tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK) + tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg) + + cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1])) + tdKcdK = thr_mma_dK.partition_C(cdK) + tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout) + + tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor) + tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg) + tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32) + + cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r) + cute.arch.fence_view_async_tmem_load() + + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dk_dtype, + num_bits_per_copy=universal_copy_bits, + ) + + tiled_gmem_store_dK = cute.make_tiled_copy( + atom_universal_copy, + layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld_dK.tiler_mn, + ) + + tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype) + + for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])): + dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale + tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype)) + + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + gdK_tile = gdK[None, None, n_block] + + tdKgdK = thr_mma_dK.partition_C(gdK_tile) + tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) + tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) + + cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g) + + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + return consumer_state_dKV + + @cute.jit + def epilogue_dK_or_dV_tma( + self, + tidx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + thr_mma: cute.core.ThrMma, + tdKVtdKV: cute.Tensor, + mdKV: cute.Tensor, + sdKV: cute.Tensor, + tma_atom_dKV: cute.CopyAtom, + thr_copy_r2s_dKV: cute.TiledCopy, + pipeline_dKV: PipelineAsync, + consumer_state_dKV: cutlass.pipeline.PipelineState, + scale: Optional[Float32], + barrier_id: Int32, + mdKV_semaphore: Optional[cute.Tensor], + ) -> cutlass.pipeline.PipelineState: + # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim) + # head_dim = head_dim_v, dk_dtype = dv_dtype + num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) + wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128 + num_wg = num_compute_threads // 128 + leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + + if const_expr(self.qhead_per_kvhead == 1): + sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 + else: + sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 + + # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8) + tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) + + head_idx_kv = head_idx // self.qhead_per_kvhead + if const_expr(self.qhead_per_kvhead == 1): + mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) + gdKV_p = cute.local_tile( + mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) + ) # (tile_n, hdim) + gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) + gdKV_epi = cute.local_tile( + gdKV, self.sdKV_epi_tile, (0, None) + ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) + else: + mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) + gdKV_p = cute.local_tile( + mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,) + ) # (tile_n * hdim) + gdKV = cute.logical_divide(gdKV_p, (self.tile_n * self.tile_hdim // num_wg,))[ + ((None, wg_idx),) + ] # (tile_n * hdim / 2) + gdKV_epi = cute.flat_divide( + gdKV, (self.sdKV_flat_epi_tile,) + ) # (tile_n * hdim / 2 / epi_stage, epi_stage) + + deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 + if const_expr(deterministic_KV): + mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] + + if const_expr(self.qhead_per_kvhead == 1): + tdKVsdKV, tdKVgdKV = cpasync.tma_partition( + tma_atom_dKV, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sdKV, 0, 2), + cute.group_modes(gdKV_epi, 0, 2), + ) # (TMA) and (TMA, EPI_STAGE) + assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" + assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" + num_epi_stages = cute.size(tdKVgdKV.shape[1]) + assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong" + else: + num_epi_stages = self.num_epi_stages + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) + + read_flag = const_expr(not deterministic_KV) + + pipeline_dKV.consumer_wait(consumer_state_dKV) + + # semaphore acquire + if const_expr(deterministic_KV): + barrier.wait_eq( + mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead + ) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + + for epi_stage in cutlass.range_constexpr(num_epi_stages): + # TMEM -> RMEM -- setup + thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx) + tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV) + tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + if const_expr(num_epi_stages > 1): + tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage] + + cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tdKVcdKV = thr_mma.partition_C(cdKV) + tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV) + tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + if const_expr(num_epi_stages > 1): + tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage] + + tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32) + + assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, ( + "RMEM<->TMEM fragment size mismatch" + ) + + # TMEM -> RMEM -- copy and fence + cute.copy(thr_copy_t2r, tdKVtdKV_t2r, tdKVrdKV_t2r) + cute.arch.fence_view_async_tmem_load() + + # RMEM -- scale and convert + if const_expr(scale is not None): + for i in cutlass.range(cute.size(tdKVrdKV_t2r.shape) // 2, unroll_full=True): + tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2( + (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale) + ) + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) + tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) + + # RMEM -> SMEM -- copy, fence and barrier + tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape) + cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + + # SMEM -> GMEM + if leader_warp: + if const_expr(self.qhead_per_kvhead == 1): + cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage]) + else: + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdKV.iterator, + gdKV_epi[None, epi_stage].iterator, + self.tma_copy_bytes["dKacc"], + ) + if const_expr(epi_stage < num_epi_stages - 1): + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + cute.arch.barrier_arrive( + barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE + ) + + # Barrier since all warps need to wait for SMEM to be freed + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE + ) + + # semaphore release + # NOTE: arrive_inc calls red_release which issues membar + if const_expr(deterministic_KV): + if leader_warp: + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1) + + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + return consumer_state_dKV diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py new file mode 100644 index 00000000000..641adef4846 --- /dev/null +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -0,0 +1,1244 @@ +import math +from typing import Callable, Optional, Type +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass.cute.arch import ProxyKind, SharedSpace +from cutlass import Float32, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum + +from flash_attn.cute import hopper_helpers as sm90_utils +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute.hopper_helpers import gemm_zero_init, gemm_w_idx +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase +from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd + + +def mma_partition_fragment_AB( + thr_mma: cute.core.ThrMma, sA: Optional[cute.Tensor], sB: Optional[cute.Tensor], swap_AB: bool +): + if const_expr(not swap_AB): + return ( + thr_mma.make_fragment_A(thr_mma.partition_A(sA)) if sA is not None else None, + thr_mma.make_fragment_B(thr_mma.partition_B(sB)) if sB is not None else None, + ) + else: + return ( + thr_mma.make_fragment_B(thr_mma.partition_B(sA)) if sA is not None else None, + thr_mma.make_fragment_A(thr_mma.partition_A(sB)) if sB is not None else None, + ) + + +class FlashAttentionBackwardSm90: + arch = 90 + + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + is_causal: bool = False, + tile_m: int = 64, + tile_n: int = 128, + Q_stage: int = 2, + dO_stage: int = 2, + PdS_stage: int = 2, + SdP_swapAB: bool = False, + dKV_swapAB: bool = False, + dQ_swapAB: bool = False, + AtomLayoutMSdP: int = 1, + AtomLayoutNdKV: int = 2, + AtomLayoutMdQ: int = 1, + num_threads: int = 384, + V_in_regs: bool = False, + ): + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv + self.qhead_per_kvhead = qhead_per_kvhead + self.is_causal = is_causal + self.is_local = False + self.tile_m = tile_m + self.tile_n = tile_n + self.num_threads = num_threads + self.Q_stage = Q_stage + self.dO_stage = dO_stage + self.PdS_stage = PdS_stage + assert self.dO_stage in [1, self.Q_stage] + assert self.PdS_stage in [1, self.Q_stage] + self.SdP_swapAB = SdP_swapAB + self.dKV_swapAB = dKV_swapAB + self.dQ_swapAB = dQ_swapAB + self.AtomLayoutMSdP = AtomLayoutMSdP + self.AtomLayoutNdKV = AtomLayoutNdKV + self.AtomLayoutMdQ = AtomLayoutMdQ + self.num_mma_warp_groups = (self.num_threads // 128) - 1 + self.mma_dkv_is_rs = ( + AtomLayoutMSdP == 1 + and AtomLayoutNdKV == self.num_mma_warp_groups + and SdP_swapAB + and not dKV_swapAB + ) + self.V_in_regs = V_in_regs + # These are tuned for speed + # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share + # them and then shuffle to get the value whenever we need? This can reduce register + # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4) + # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows. + # TODO: impl these for hdim 64 + self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 + self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 + + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + Q_stage, + num_threads, + V_in_regs=False, + ) -> bool: + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if tile_n % 16 != 0: + return False + if num_threads % 32 != 0: + return False + if (tile_m * 2) % num_threads != 0: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mdO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric], + mdPsum_type: Type[cutlass.Numeric], + mdQaccum_type: Type[cutlass.Numeric], + mdK_type: Type[cutlass.Numeric], + mdV_type: Type[cutlass.Numeric], + ): + # Get the data type and check if it is fp16 or bf16 + if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): + raise TypeError("All tensors must have the same data type") + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if const_expr(mLSE_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr(mdPsum_type not in [Float32]): + raise TypeError("dPsum tensor must be Float32") + if const_expr(mdQaccum_type not in [Float32]): + raise TypeError("dQaccum tensor must be Float32") + if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not (mdK_type == mdV_type == mQ_type)): + raise TypeError("mdK and mdV tensors must have the same data type as mQ") + else: + if const_expr(not (mdK_type == mdV_type == Float32)): + raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32") + assert mQ_type == self.dtype + + def _setup_attributes(self): + self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [ + sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage) + for shape, stage in [ + ((self.tile_m, self.tile_hdim), self.Q_stage), + ((self.tile_n, self.tile_hdim), None), + ((self.tile_n, self.tile_hdimv), None), + ((self.tile_m, self.tile_hdimv), self.dO_stage), + ((self.tile_m, self.tile_n), self.PdS_stage), + ] + ] + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups) + ) + # dQaccum R->S + self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), + # thr_layout + cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), + cute.make_layout(128 // Float32.width), # val_layout + ) + + def _get_tiled_mma(self): + # S = Q @ K.T, dP = dO @ V.T + atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP) + tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1]) + tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1]) + + (1,), + tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1], + ) + # dV = P.T @ dO, dK = dS.T @ Q + atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV) + tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1]) + tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1]) + tiled_mma_dK, tiled_mma_dV = [ + sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.MN + if not self.mma_dkv_is_rs + else warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1]) + + (1,), + tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1], + a_source=warpgroup.OperandSource.RMEM + if self.mma_dkv_is_rs + else warpgroup.OperandSource.SMEM, + ) + for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV) + ] + # dQ = dS @ K + atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ) + tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) + tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,), + tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], + ) + return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ + + def _get_shared_storage_cls(self): + sQ_alignment = sK_alignment = sV_alighment = sdQaccum_alignment = sdO_alignment = 1024 + + sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [ + cute.struct.Align[cute.struct.MemRange[type, cute.cosize(layout)], alignment] + for (layout, type, alignment) in [ + (self.sQ_layout, self.dtype, sQ_alignment), + (self.sK_layout, self.dtype, sK_alignment), + (self.sV_layout, self.dtype, sV_alighment), + (self.sdO_layout, self.dtype, sdO_alignment), + (self.sdQaccum_layout, Float32, sdQaccum_alignment), + ] + ] + + cosize_sdS = cute.cosize(self.sPdS_layout) + cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0 + sLSE_struct = cute.struct.Align[ + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128 + ] + sdPsum_struct = cute.struct.Align[ + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128 + ] + + @cute.struct + class SharedStorageQKV: + mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2] + mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2] + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sQ: sQ_struct + sV: sV_struct + sK: sK_struct + sdO: sdO_struct + sP: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024] + sdQaccum: sdQaccum_struct + + return SharedStorageQKV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + ): + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) + ) + ) + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) + ] + + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + mQ, mK, mV, mdK, mdV, mdO = [ + utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdK, mdV, mdO) + ] + LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) + mLSE, mdPsum, mdQaccum = [ + utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) + ] + + tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma() + + self.num_mma_threads = tiled_mma_SdP.size + assert self.num_mma_threads + 128 == self.num_threads + + self.num_threads_per_warp_group = 128 + self.num_producer_threads = 32 + + self.num_mma_regs = 240 + self.num_producer_regs = 24 + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 + + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ("dO", mdO, self.sdO_layout), + ] + } + self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dQ"] = ( + self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups + ) + + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mQ, + cute.select(self.sQ_layout, mode=[0, 1]), + (self.tile_m, self.tile_hdim), + ) + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + ) + tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mdO, + cute.select(self.sdO_layout, mode=[0, 1]), + (self.tile_m, self.tile_hdimv), + ) + tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mdK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + ) + tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mdV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + ) + + TileScheduler = SingleTileScheduler + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), + cute.size(mK.shape[2]), + cute.size(mK.shape[3]), + 1, # num_splits + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + mCuSeqlensQ=None, + mSeqUsedQ=None, + qhead_per_kvhead_packgqa=1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=False, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + LOG2_E = math.log2(math.e) + softmax_scale_log2 = softmax_scale * LOG2_E + + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + tma_tensor_dO, + tma_tensor_dK, + tma_tensor_dV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + tma_atom_dK, + tma_atom_dV, + mLSE, + mdPsum, + mdQaccum, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sPdS_layout, + self.sdO_layout, + self.sdQaccum_layout, + self.r2s_tiled_copy_dQaccum, + tiled_mma_SdP, + tiled_mma_dK, + tiled_mma_dV, + tiled_mma_dQ, + softmax_scale_log2, + softmax_scale, + tile_sched_params, + TileScheduler, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sPdS_layout: cute.ComposedLayout, + sdO_layout: cute.ComposedLayout, + sdQaccum_layout: cute.Layout, + r2s_tiled_copy_dQaccum: cute.TiledCopy, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + softmax_scale_log2, + softmax_scale, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # prefetch TMA descriptors + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_dO) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group + ) + pipeline_Q = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_Q.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"], + init_wait=False, + ) + pipeline_dO = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_dO.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"], + init_wait=True, + ) + + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sP = None + if const_expr(not self.mma_dkv_is_rs): + sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sLSE = storage.sLSE.get_tensor( + cute.make_layout( + (self.tile_m, self.Q_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + ) + sdPsum = storage.sdPsum.get_tensor( + cute.make_layout( + (self.tile_m, self.dO_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + ) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + None, + None, + qhead_per_kvhead_packgqa=1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=None, + mCuSeqlensK=None, + mSeqUsedQ=None, + mSeqUsedK=None, + ) + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=None, + window_size_right=None, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: + cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) + if warp_idx == 0: + self.load( + mQ, + mK, + mV, + mdO, + mLSE, + mdPsum, + sQ, + sK, + sV, + sdO, + sLSE, + sdPsum, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + pipeline_Q, + pipeline_dO, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + if warp_idx == 1: + for warp_group_idx in cutlass.range(self.num_mma_warp_groups): + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + self.dQaccum_store(mdQaccum, sdQaccum, block_info, TileSchedulerCls, SeqlenInfoCls) + else: + cute.arch.warpgroup_reg_alloc(self.num_mma_regs) + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_SdP, + tiled_mma_dK, + tiled_mma_dV, + tiled_mma_dQ, + mdK, + mdV, + mdQaccum, + sQ, + sK, + sV, + sdO, + sP, + sdS, + sLSE, + sdPsum, + sdQaccum, + pipeline_Q, + pipeline_dO, + tidx, + tma_atom_dK, + tma_atom_dV, + r2s_tiled_copy_dQaccum, + softmax_scale_log2, + softmax_scale, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + + if warp_idx_in_wg == 0: + producer_state_Q = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage + ) + producer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mK_cur = mK[None, None, head_idx, batch_idx] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + mV_cur = mV[None, None, head_idx, batch_idx] + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + + mQ_cur = mQ[None, None, head_idx, batch_idx] + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) + mdO_cur = mdO[None, None, head_idx, batch_idx] + gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0)) + mLSE_cur = mLSE[None, head_idx, batch_idx] + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) + mdPsum_cur = mdPsum[None, head_idx, batch_idx] + gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,)) + + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True + ) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV, single_stage=True + ) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ + ) + load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) + load_dO, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dO, 0, cute.make_layout(1), gdO, sdO + ) + load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) + load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) + load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q) + load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) + load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO) + + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + # First iteration: load K together w Q & LSE, then V together w dO & dPsum + m_block = m_block_min + pipeline_Q.producer_acquire( + producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) + load_Q(m_block, producer_state=producer_state_Q) + # cp.async.bulk is using ptx, so we need to elect one thread to do it + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire( + producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + # Subsequent iterations: load Q & LSE, then dO & dPsum + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + # cp.async.bulk is using ptx, so we need to elect one thread to do it + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma( + self, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + mdK: cute.Tensor, + mdV: cute.Tensor, + mdQaccum: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sP: Optional[cute.Tensor], + sdS: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + sdQaccum: cute.Tensor, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + tidx: Int32, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, + r2s_tiled_copy_dQaccum: cute.TiledCopy, + softmax_scale_log2: Float32, + softmax_scale: Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma_SdP = tiled_mma_SdP.get_slice(tidx) + wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx)) + # S = Q @ K.T + tSrQ, tSrK = mma_partition_fragment_AB(wg_mma_SdP, sQ, sK, self.SdP_swapAB) + # dP = dO @ V.T + tdPrdO, tdPrV = mma_partition_fragment_AB(wg_mma_SdP, sdO, sV, self.SdP_swapAB) + # dV += P.T @ dO + sPt = utils.transpose_view(sP) if sP is not None else None + sdOt = utils.transpose_view(sdO) + tdVrPt, tdVrdOt = mma_partition_fragment_AB(wg_mma_dV, sPt, sdOt, self.dKV_swapAB) + # dK += dS.T @ Q + sdSt = utils.transpose_view(sdS) + sQt = utils.transpose_view(sQ) + tdKrdSt, tdKrQt = mma_partition_fragment_AB(wg_mma_dK, sdSt, sQt, self.dKV_swapAB) + # dQ = dS @ K + sKt = utils.transpose_view(sK) + tdQrdS, tdQrKt = mma_partition_fragment_AB(wg_mma_dQ, sdS, sKt, self.dQ_swapAB) + + # Smem copy atom tiling + smem_copy_atom_PdS = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.SdP_swapAB + ) + smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( + tidx + ) + tPsP = None + if const_expr(sP is not None): + tPsP = smem_thr_copy_PdS.partition_D(sP if const_expr(not self.SdP_swapAB) else sPt) + tdSsdS = smem_thr_copy_PdS.partition_D(sdS if const_expr(not self.SdP_swapAB) else sdSt) + + sLSE_mma = cute.make_tensor( + sLSE.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.Q_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + sdPsum_mma = cute.make_tensor( + sdPsum.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.dO_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + if const_expr(self.SdP_swapAB): + sLSE_mma = utils.transpose_view(sLSE_mma) + sdPsum_mma = utils.transpose_view(sdPsum_mma) + LSEslice = (None, 0, None) if const_expr(not self.SdP_swapAB) else (0, None, None) + tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice] + tLSEsdPsum = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice] + + smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + + dV_shape = (self.tile_n, self.tile_hdimv) + acc_dV = cute.make_fragment( + tiled_mma_dV.partition_shape_C(dV_shape if not self.dKV_swapAB else dV_shape[::-1]), + Float32, + ) + dK_shape = (self.tile_n, self.tile_hdim) + acc_dK = cute.make_fragment( + tiled_mma_dK.partition_shape_C(dK_shape if not self.dKV_swapAB else dK_shape[::-1]), + Float32, + ) + + mma_qk_fn = partial( + gemm_zero_init, + tiled_mma_SdP, + (self.tile_m, self.tile_n), + tSrQ, + tSrK, + swap_AB=self.SdP_swapAB, + ) + mma_dov_fn = partial( + gemm_zero_init, + tiled_mma_SdP, + (self.tile_m, self.tile_n), + tdPrdO, + tdPrV, + swap_AB=self.SdP_swapAB, + ) + if const_expr(not self.mma_dkv_is_rs): + mma_pdo_fn = partial( + gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB + ) + mma_dsq_fn = partial( + gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB + ) + else: + assert not self.dKV_swapAB + mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt) + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt) + mma_dsk_fn = partial( + gemm_zero_init, + tiled_mma_dQ, + (self.tile_m, self.tile_hdim), + tdQrdS, + tdQrKt, + swap_AB=self.dQ_swapAB, + ) + + mma_one_m_block_all = partial( + self.mma_one_m_block, + warp_group_idx=warp_group_idx, + mma_qk_fn=mma_qk_fn, + mma_dov_fn=mma_dov_fn, + mma_pdo_fn=mma_pdo_fn, + mma_dsq_fn=mma_dsq_fn, + mma_dsk_fn=mma_dsk_fn, + pipeline_Q=pipeline_Q, + pipeline_dO=pipeline_dO, + tLSEsLSE=tLSEsLSE, + tLSEsdPsum=tLSEsdPsum, + tPsP=tPsP, + tdSsdS=tdSsdS, + tdQsdQaccum=tdQsdQaccum, + smem_thr_copy_PdS=smem_thr_copy_PdS, + smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, + softmax_scale_log2=softmax_scale_log2, + # acc_dV=acc_dV, + # acc_dK=acc_dK, + ) + + consumer_state_Q = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage + ) + consumer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, + batch_idx=None, + head_idx=None, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=self.is_causal, + mask_local=self.is_local, + ) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) + dKV_accumulate = False + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + consumer_state_Q, consumer_state_dO = mma_one_m_block_all( + m_block, + consumer_state_Q, + consumer_state_dO, + mask_fn=mask_fn, + dKV_accumulate=dKV_accumulate, + ) + dKV_accumulate = True + + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) + # scale dK + acc_dK.store(acc_dK.load() * softmax_scale) + self.epilogue_dKV( + acc_dV, + mdV, + sV, + acc_dK, + mdK, + sK, + seqlen, + tma_atom_dK, + tma_atom_dV, + tiled_mma_dK, + tiled_mma_dV, + tidx, + n_block, + head_idx, + batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma_one_m_block( + self, + m_block: Int32, + consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + warp_group_idx: Int32, + mma_qk_fn: Callable, + mma_dov_fn: Callable, + mma_pdo_fn: Callable, + mma_dsq_fn: Callable, + mma_dsk_fn: Callable, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + tLSEsLSE: cute.Tensor, + tLSEsdPsum: cute.Tensor, + tPsP: Optional[cute.Tensor], + tdSsdS: Optional[cute.Tensor], + tdQsdQaccum: cute.Tensor, + smem_thr_copy_PdS: cute.TiledCopy, + smem_thr_copy_dQaccum: cute.TiledCopy, + softmax_scale_log2: Float32, + mask_fn: Optional[Callable] = None, + # acc_dV, + # acc_dK, + dKV_accumulate: Boolean = True, + ): + consumer_state_dO_cur = ( + consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q + ) + smem_idx_Q = consumer_state_Q.index + smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0 + smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0 + # (1) [GEMM 1] S = Q @ K^T + pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q)) + acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) + tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q]) + # (2) [GEMM 2] dP = dO @ V.T + pipeline_dO.consumer_wait( + consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur) + ) + acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) + # (3) [Pointwise 1] P = exp(S - LSE) + if cutlass.const_expr(mask_fn is not None): + mask_fn(acc_S, m_block=m_block) + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB) + # if cute.arch.thread_idx()[0] == 256: cute.print_tensor(acc_S_mn) + for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): + for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): + acc_S_mn[r, c] = cute.math.exp2( + acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True + ) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) + tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO]) + + # Convert P from f32 -> f16 + tdVrP = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_S), self.dtype) + # R2S for P + if const_expr(not self.mma_dkv_is_rs): + # sync to ensure P has already been used in the previous iteration before overwriting + if const_expr(self.PdS_stage == 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + tPrP = smem_thr_copy_PdS.retile(tdVrP) + cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, smem_idx_PdS]) + + # (4) [Pointwise 2] dS = P*(dP-dPsum) + warpgroup.wait_group(0) + acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) + for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): + for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): + acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) + # Convert dS from f32 -> f16 + tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype) + + # If there's double buffering on dS, we don't need to sync here. + # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. + # But because both WGs have to sync at the end of the loop and double buffering, + # this race condition is not possible. + # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and + # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs. + if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)): + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + + # R2S for dS + tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) + cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, smem_idx_PdS]) + + # (5) [GEMM 3] dV += P.T @ dO + if const_expr(not self.mma_dkv_is_rs): + mma_pdo_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1 + ) + else: + mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1) + + # smem fence to make sure sdS is written before it's read by WGMMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + # (6) [GEMM 4] dQ = dS @ K + acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) + pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done + + # (7) [GEMM 5] dK += dS.T @ Q + if const_expr(not self.mma_dkv_is_rs): + mma_dsq_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 + ) + else: + mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) + + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + + warpgroup.wait_group(0) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK) + pipeline_Q.consumer_release(consumer_state_Q) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block) + + consumer_state_Q.advance() + consumer_state_dO.advance() + return consumer_state_Q, consumer_state_dO + + @cute.jit + def epilogue_dKV( + self, + acc_dV: cute.Tensor, + mdV: cute.Tensor, + sV: cute.Tensor, + acc_dK: cute.Tensor, + mdK: cute.Tensor, + sK: cute.Tensor, + seqlen: SeqlenInfoQK, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tidx: Int32, + n_block: Int32, + head_idx: Int32, + batch_idx: Int32, + ): + rdV = cute.make_fragment_like(acc_dV, self.dtype) + rdV.store(acc_dV.load().to(self.dtype)) + rdK = utils.cvt_f16(acc_dK, self.dtype) + + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + + smem_copy_atom_dKV = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=self.dKV_swapAB, num_matrices=4), + self.dtype, + ) + smem_thr_copy_dK = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dK).get_slice(tidx) + smem_thr_copy_dV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dV).get_slice(tidx) + mdV_cur = mdV[None, None, head_idx, batch_idx] + mdK_cur = mdK[None, None, head_idx, batch_idx] + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + store_dK, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True + ) + store_dV, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True + ) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # rmem -> smem + taccdVrdV = smem_thr_copy_dV.retile(rdV) + sdV = sV if const_expr(not self.dKV_swapAB) else utils.transpose_view(sV) # reuse sV SMEM + taccdVsdV = smem_thr_copy_dV.partition_D(sdV) + cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) + # ensure smem writes are visible to TMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + if warp_idx == 4: + store_dV() + taccdKrdK = smem_thr_copy_dK.retile(rdK) + sdK = sK if const_expr(not self.dKV_swapAB) else utils.transpose_view(sK) # reuse sK SMEM + taccdKsdK = smem_thr_copy_dK.partition_D(sdK) # reuse sK SMEM + cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + # ensure smem writes are visible to TMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + # smem -> gmem + if warp_idx == 4: + store_dK() + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + + @cute.jit + def dQaccum_store( + self, + mdQaccum: cute.Tensor, + sdQaccum: cute.Tensor, + block_info: BlockInfo, + TileSchedulerCls: cutlass.Constexpr[Callable], + SeqlenInfoCls: cutlass.Constexpr[Callable], + ): + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) + # (M * K / WG, WG, _) + gdQaccum = cute.flat_divide( + gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,) + ) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, warp_group_idx].iterator, + gdQaccum[None, warp_group_idx, m_block].iterator, + self.tma_copy_bytes["dQ"], + ) + cute.arch.cp_async_bulk_commit_group() + for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + cute.arch.cp_async_bulk_wait_group( + self.num_mma_warp_groups - 1 - warp_group_idx, read=True + ) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index d8ddd1ae443..57874f6559f 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -7,31 +7,47 @@ import math from types import SimpleNamespace -from typing import Type, Callable, Optional +from typing import Type, Callable, Optional, List from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute +from cutlass import Constexpr, Float32, Int32, const_expr, Boolean from cutlass.cute.nvgpu import cpasync, warp, warpgroup -import cutlass.utils.ampere_helpers as sm80_utils_basic +from cutlass.cute.arch import ProxyKind, SharedSpace +import cutlass.utils as utils_basic +from cutlass.utils import LayoutEnum import cutlass.utils.hopper_helpers as sm90_utils_basic from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute import utils +from flash_attn.cute import copy_utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + produce_block_sparse_loads, + consume_block_sparse_loads, +) from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) +from cutlass.cute import FastDivmodDivisor class FlashAttentionForwardBase: - arch: int = 80 def __init__( @@ -41,13 +57,16 @@ def __init__( head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, is_causal: bool = False, - has_softcap: bool = False, + is_local: bool = False, pack_gqa: bool = True, - m_block_size: int = 128, - n_block_size: int = 128, + tile_m: int = 128, + tile_n: int = 128, num_stages: int = 1, num_threads: int = 128, Q_in_regs: bool = False, + score_mod: Optional[cutlass.Constexpr] = None, + mask_mod: Optional[cutlass.Constexpr] = None, + has_aux_tensors: bool = False, ): """Initializes the configuration for a flash attention kernel. @@ -56,38 +75,56 @@ def __init__( :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int - :param n_block_size: n block size - :type n_block_size: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int :param num_threads: number of threads :type num_threads: int :param is_causal: is causal + :param score_mod: A callable that takes the attention scores and applies a modification. + Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any`` + :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. + Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean`` """ self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 - self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v - self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) # Can save registers (and hence be faster) if we don't have to check hdim predication - self.check_hdim_oob = head_dim != self.head_dim_padded - self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal - self.has_softcap = has_softcap + self.is_local = is_local self.pack_gqa = pack_gqa - self.m_block_size = m_block_size - self.n_block_size = n_block_size + self.tile_m = tile_m + self.tile_n = tile_n self.num_threads = num_threads self.num_stages = num_stages self.Q_in_regs = Q_in_regs + self.score_mod = score_mod + self.mask_mod = mask_mod + self.qk_acc_dtype = Float32 + if const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 2 @staticmethod def can_implement( - dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages, num_threads, is_causal, - Q_in_regs=False + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + num_stages, + num_threads, + is_causal, + Q_in_regs=False, ) -> bool: """Check if the kernel can be implemented with the given parameters. @@ -95,10 +132,10 @@ def can_implement( :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int - :param n_block_size: n block size - :type n_block_size: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int :param num_threads: number of threads :type num_threads: int :param is_causal: is causal @@ -113,23 +150,25 @@ def can_implement( return False if head_dim_v % 8 != 0: return False - if n_block_size % 16 != 0: + if tile_n % 16 != 0: return False if num_threads % 32 != 0: return False # Check if block size setting is out of shared memory capacity # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size - smem_usage_Q = m_block_size * head_dim * 2 - smem_usage_K = n_block_size * head_dim * num_stages * 2 - smem_usage_V = n_block_size * head_dim_v * num_stages * 2 - smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + smem_usage_Q = tile_m * head_dim * 2 + smem_usage_K = tile_n * head_dim * num_stages * 2 + smem_usage_V = tile_n * head_dim_v * num_stages * 2 + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + ) smem_usage = smem_usage_QV + smem_usage_K # TODO: sm86 and sm89 - smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False # Check if twice the block size is divisible by the number of threads - if (m_block_size * 2) % num_threads != 0: + if (tile_m * 2) % num_threads != 0: return False return True @@ -146,19 +185,19 @@ def _check_type( mSeqUsedK_type: Type[cutlass.Numeric] | None, ): # Get the data type and check if it is fp16 or bf16 - if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mO_type)): + if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): raise TypeError("All tensors must have the same data type") - if cutlass.const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(mLSE_type not in [None, cutlass.Float32]): + if const_expr(mLSE_type not in [None, Float32]): raise TypeError("LSE tensor must be Float32") - if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): + if const_expr(mCuSeqlensQ_type not in [None, Int32]): raise TypeError("cu_seqlens_q tensor must be Int32") - if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): + if const_expr(mCuSeqlensK_type not in [None, Int32]): raise TypeError("cu_seqlens_k tensor must be Int32") - if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): + if const_expr(mSeqUsedQ_type not in [None, Int32]): raise TypeError("seqused_q tensor must be Int32") - if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): + if const_expr(mSeqUsedK_type not in [None, Int32]): raise TypeError("seqused_k tensor must be Int32") assert mQ_type == self.dtype @@ -166,22 +205,34 @@ def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: Q/K/V # /////////////////////////////////////////////////////////////////////////////// - sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = self._get_smem_layout_atom() + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = ( + self._get_smem_layout_atom() + ) self.sQ_layout = cute.tile_to_shape( - sQ_layout_atom, (self.m_block_size, self.head_dim_padded), (0, 1), + sQ_layout_atom, + (self.tile_m, self.tile_hdim), + (0, 1), ) self.sK_layout = cute.tile_to_shape( - sK_layout_atom, (self.n_block_size, self.head_dim_padded, self.num_stages), (0, 1, 2), + sK_layout_atom, + (self.tile_n, self.tile_hdim, self.num_stages), + (0, 1, 2), ) self.sV_layout = cute.tile_to_shape( - sV_layout_atom, (self.n_block_size, self.head_dim_v_padded, self.num_stages), (0, 1, 2), + sV_layout_atom, + (self.tile_n, self.tile_hdimv, self.num_stages), + (0, 1, 2), ) self.sO_layout = cute.tile_to_shape( - sO_layout_atom, (self.m_block_size, self.head_dim_v_padded), (0, 1), + sO_layout_atom, + (self.tile_m, self.tile_hdimv), + (0, 1), ) - if cutlass.const_expr(sP_layout_atom is not None): + if const_expr(sP_layout_atom is not None): self.sP_layout = cute.tile_to_shape( - sP_layout_atom, (self.m_block_size, self.n_block_size), (0, 1), + sP_layout_atom, + (self.tile_m, self.tile_n), + (0, 1), ) else: self.sP_layout = None @@ -200,31 +251,41 @@ def _setup_attributes(self): ) # atom_universal_copy: universal copy atom for O store atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, ) # tQ_layout and tK_layout: thread layout for QK load tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems - assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" - assert self.num_producer_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" + assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) + assert self.num_producer_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) tQ_layout = cute.make_ordered_layout( - (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), ) tK_layout = cute.make_ordered_layout( - (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we load Q - assert self.m_block_size % tQ_layout.shape[0] == 0 + assert self.tile_m % tQ_layout.shape[0] == 0 tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems tV_layout = cute.make_ordered_layout( - (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), ) # TODO: need a different layout for O if O dtype is not the same as V dtype # tO_layout: thread layout for O store tO_layout = cute.make_ordered_layout( - (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we store O - assert self.m_block_size % tO_layout.shape[0] == 0 + assert self.tile_m % tO_layout.shape[0] == 0 # Value layouts for copies vQKV_layout = cute.make_layout((1, async_copy_elems)) @@ -253,8 +314,7 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, - softcap: cutlass.Float32, + softmax_scale: Float32, stream: cuda.CUstream, ): """Configures and launches the flash attention kernel. @@ -272,41 +332,45 @@ def epilogue( mO: cute.Tensor, mLSE: Optional[cute.Tensor], sO: cute.Tensor, - seqlen: SeqlenInfo, + seqlen: SeqlenInfoQK, gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], tiled_mma: cute.TiledMma, - tidx: cutlass.Int32, - m_block: cutlass.Int32, - head_idx: cutlass.Int32, - batch_idx: cutlass.Int32, - is_varlen: cutlass.Constexpr[bool] = False, + tidx: Int32, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, ): # store acc_O rO = cute.make_fragment_like(acc_O, self.dtype) rO.store(acc_O.load().to(self.dtype)) # Make sure all threads have finished reading V - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + ) smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_O = utils.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) + smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) taccOsO = smem_thr_copy_O.partition_D(sO) # copy acc O from rmem to smem with the smem copy atom cute.copy(smem_copy_atom_O, taccOrO, taccOsO) - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) + cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) + pack_gqa = PackGQA( + self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead + ) # Write LSE from rmem -> gmem - if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not is_varlen): + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: - mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) - if cutlass.const_expr(not self.pack_gqa): - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + if const_expr(not self.pack_gqa): + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) gLSE_expanded_layout = cute.append( - gLSE.layout, cute.make_layout((self.head_dim_v_padded,), stride=(0,)) + gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,)) ) gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout) thr_mma = tiled_mma.get_slice(tidx) @@ -317,58 +381,72 @@ def epilogue( # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0: for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): - if t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - taccOcO[0][0]: + if ( + t0accOcO[m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] + ): taccOgLSE[m, 0] = lse[m] else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) - if cutlass.const_expr(not is_varlen): + if const_expr(not seqlen.has_cu_seqlens_q): mO_cur = mO[None, None, head_idx, batch_idx] else: - mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, None, head_idx]) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) # thr_mma = tiled_mma.get_slice(tidx) # taccOgO = thr_mma.partition_C(gO) # cute.autovec_copy(rO, taccOgO) # sync to make sure all smem stores are done - if cutlass.const_expr(self.use_tma_O): + if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - utils.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) - gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) - tOsO, tOgO = cpasync.tma_partition( - tma_atom_O, - 0, - cute.make_layout(1), - cute.group_modes(sO, 0, 2), - cute.group_modes(gO, 0, 2), + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True ) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 4: - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) - cute.copy(tma_atom_O, tOsO, tOgO) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + store_O() cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) else: - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads, + ) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO) tOrO = cute.make_fragment_like(tOsO, self.dtype) # load acc O from smem to rmem for wider vectorization cute.autovec_copy(tOsO, tOrO) - if cutlass.const_expr(not self.pack_gqa): - gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) + if const_expr(not self.pack_gqa): + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) tOgO = gmem_thr_copy_O.partition_D(gO) tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) # copy acc O from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] + ): cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], tOgO[None, rest_m, None], - pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, ) else: pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) @@ -383,24 +461,24 @@ def load_Q( gmem_thr_copy: cute.TiledCopy, gQ: cute.Tensor, sQ: cute.Tensor, - block: cutlass.Int32, - seqlen: cutlass.Int32, - headdim: cutlass.Int32, + block: Int32, + seqlen: Int32, + headdim: Int32, ): tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ) - cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) tQcQ = gmem_thr_copy.partition_S(cQ) t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) tQpQ = utils.predicate_k(tQcQ, limit=headdim) - for m in range(cute.size(tQsQ.shape[1])): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. - if t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]: + if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]: cute.copy( gmem_thr_copy, tQgQ[None, m, None], tQsQ[None, m, None], - pred=tQpQ[None, m, None] if self.check_hdim_oob else None, + pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs @@ -413,39 +491,41 @@ def load_K( tKcK: cute.Tensor, t0KcK: cute.Tensor, tKpK: cute.Tensor, - block: cutlass.Int32, - smem_pipe_write: cutlass.Int32, - seqlen: cutlass.Int32, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load K? - is_even_n_smem_k = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 - if cutlass.const_expr(need_predicates or not is_even_n_smem_k): + is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_k): # Instead of using tKcK, we using t0KcK and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. - if cutlass.const_expr(is_even_n_smem_k): - seqlen_limit = seqlen - block * self.n_block_size + if const_expr(is_even_n_smem_k): + seqlen_limit = seqlen - block * self.tile_n else: - if cutlass.const_expr(not need_predicates): - seqlen_limit = self.n_block_size + if const_expr(not need_predicates): + seqlen_limit = self.tile_n else: - seqlen_limit = cutlass.min(seqlen - block * self.n_block_size, self.n_block_size) + seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n) seqlen_limit -= tKcK[0][0] - for n in range(cute.size(tKsK.shape[1])): + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): if t0KcK[0, n, 0][0] < seqlen_limit: cute.copy( gmem_tiled_copy, tKgK[None, n, None, block], - tKsK[None, n, None, smem_pipe_write if self.num_stages > 1 else 0], - pred=tKpK[None, n, None] if self.check_hdim_oob else None, + tKsK[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], + pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sK smem tiles since we'll mask out the scores anyway. else: cute.copy( gmem_tiled_copy, tKgK[None, None, None, block], - tKsK[None, None, None, smem_pipe_write if self.num_stages > 1 else 0], - pred=tKpK if self.check_hdim_oob else None, + tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tKpK if const_expr(self.check_hdim_oob) else None, ) @cute.jit @@ -457,58 +537,65 @@ def load_V( tVcV: cute.Tensor, t0VcV: cute.Tensor, tVpV: cute.Tensor, - block: cutlass.Int32, - smem_pipe_write: cutlass.Int32, - seqlen: cutlass.Int32, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load V? - is_even_n_smem_v = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 - if cutlass.const_expr(need_predicates or not is_even_n_smem_v): - for n in range(cute.size(tVsV.shape[1])): + is_even_n_smem_v = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_v): + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked - if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: - predicate = tVpV[None, n, None] if self.check_hdim_v_oob else None - if cutlass.const_expr(need_predicates): - seqlen_limit = seqlen - block * self.n_block_size - tVcV[0][0] + if ( + is_even_n_smem_v + or n < cute.size(tVsV.shape[1]) - 1 + or tVcV[0, n, 0][0] < self.tile_n + ): + predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None + if const_expr(need_predicates): + seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0] predicate_n = t0VcV[0, n, 0][0] < seqlen_limit predicate = cute.make_fragment_like(tVpV[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tVpV[i, n, k] if self.check_hdim_v_oob else True) and predicate_n + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = ( + tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True + ) and predicate_n cute.copy( gmem_tiled_copy, tVgV[None, n, None, block], - tVsV[None, n, None, smem_pipe_write if self.num_stages > 1 else 0], + tVsV[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], pred=predicate, ) else: cute.copy( gmem_tiled_copy, tVgV[None, None, None, block], - tVsV[None, None, None, smem_pipe_write if self.num_stages > 1 else 0], - pred=tVpV if self.check_hdim_v_oob else None, + tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tVpV if const_expr(self.check_hdim_v_oob) else None, ) class FlashAttentionForwardSm80(FlashAttentionForwardBase): - def _get_smem_layout_atom(self): - sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded) + sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim) sK_layout_atom = sQ_layout_atom - sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded) + sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv) sO_layout_atom = sV_layout_atom sP_layout_atom = None return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom def _get_tiled_mma(self): tiled_mma_qk = cute.make_tiled_mma( - warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), (self.num_threads // 32, 1, 1), permutation_mnk=(self.num_threads // 32 * 16, 16, 16), ) tiled_mma_pv = cute.make_tiled_mma( - warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), (self.num_threads // 32, 1, 1), permutation_mnk=(self.num_threads // 32 * 16, 16, 16), ) @@ -519,7 +606,7 @@ def _get_shared_storage_cls(self): cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) ] - cosize_sQV = utils.max_constexpr(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] @cute.struct @@ -533,7 +620,7 @@ class SharedStorageSharedQV: sQ: sQV_struct sK: sK_struct - return SharedStorageQKV if cutlass.const_expr(not self.Q_in_regs) else SharedStorageSharedQV + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV @cute.jit def __call__( @@ -543,16 +630,22 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, - softcap: cutlass.Float32, stream: cuda.CUstream, + softmax_scale: Optional[Float32] = None, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + learnable_sink: Optional[cute.Tensor] = None, + aux_tensors=None, ): """Configures and launches the flash attention kernel. mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ - self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))) + assert learnable_sink is None, "Learnable sink is not supported in this kernel" + self._check_type( + *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE)) + ) tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_pv.size self.num_producer_threads = self.num_threads @@ -562,26 +655,47 @@ def __call__( self.use_tma_O = self.arch >= 90 self._setup_attributes() SharedStorage = self._get_shared_storage_cls() - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) for t in (mQ, mK, mV, mO)] + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) + for t in (mQ, mK, mV, mO) + ] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) # grid_dim: (m_block, num_head, batch_size) grid_dim = ( - cute.ceil_div(mQ.shape[0], self.m_block_size), + cute.ceil_div(mQ.shape[0], self.tile_m), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]), ) - # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - # Right after this, we multiply by log2(e) before applying exp2. - # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) - # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if cutlass.const_expr(not self.has_softcap): - softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = cutlass.Float32(0.0) + if const_expr(self.score_mod is None): + softmax_scale_log2 = Float32(softmax_scale * LOG2_E) + softmax_scale = None else: - softmax_scale_log2 = softcap * LOG2_E - softcap_val = softmax_scale / softcap + # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = Float32(LOG2_E) + softmax_scale = Float32(softmax_scale) + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.kernel( mQ, mK, @@ -589,7 +703,9 @@ def __call__( mO, mLSE, softmax_scale_log2, - softcap_val, + softmax_scale, + window_size_left, + window_size_right, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -602,6 +718,8 @@ def __call__( tiled_mma_qk, tiled_mma_pv, SharedStorage, + aux_tensors, + fastdiv_mods, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -617,8 +735,10 @@ def kernel( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - softmax_scale_log2: cutlass.Float32, - softcap_val: cutlass.Float32, + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -631,16 +751,24 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, SharedStorage: cutlass.Constexpr, + aux_tensors=None, + fastdiv_mods=None, ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() m_block, num_head, batch_size = cute.arch.block_idx() block_info = BlockInfo( - self.m_block_size, self.n_block_size, self.is_causal, - self.qhead_per_kvhead if self.pack_gqa else 1, + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - seqlen = SeqlenInfo(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0]) + seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # TODO: return early if n_block_max == 0 # if self.is_causal: @@ -651,9 +779,9 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - blkQ_shape = (self.m_block_size, self.head_dim_padded) - blkK_shape = (self.n_block_size, self.head_dim_padded) - blkV_shape = (self.n_block_size, self.head_dim_v_padded) + blkQ_shape = (self.tile_m, self.tile_hdim) + blkK_shape = (self.tile_n, self.tile_hdim) + blkV_shape = (self.tile_n, self.tile_hdimv) gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0)) num_head_kv = num_head // self.qhead_per_kvhead gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0)) @@ -666,11 +794,11 @@ def kernel( storage = smem.allocate(SharedStorage) sQ = storage.sQ.get_tensor(sQ_layout) sK = storage.sK.get_tensor(sK_layout) - if cutlass.const_expr(not self.Q_in_regs): + if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout) else: sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) - # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = utils.transpose_view(sV) gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx) @@ -688,18 +816,20 @@ def kernel( tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ)) tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0])) tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0])) - acc_shape_O = thr_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) - acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) + acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) + acc_O = cute.make_fragment(acc_shape_O, Float32) acc_O.fill(0.0) # /////////////////////////////////////////////////////////////////////////////// # Smem copy atom tiling # /////////////////////////////////////////////////////////////////////////////// smem_copy_atom_QK = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self.dtype, ) smem_copy_atom_V = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self.dtype, ) smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) @@ -714,53 +844,70 @@ def kernel( # of tile_shape # /////////////////////////////////////////////////////////////////////////////// # Construct identity layout for KV - cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) tKcK = gmem_thr_copy_K.partition_S(cK) t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK) - if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + if const_expr(self.tile_hdim == self.tile_hdimv): tVcV = tKcK t0VcV = t0KcK else: - cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) + cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv)) tVcV = gmem_thr_copy_V.partition_S(cV) t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV) # Allocate predicate tensors for m and n, here we only allocate the tile of k, and # use "if" on the mn dimension. # This is to reduce register pressure and gets 2-3% performance gain. tKpK = utils.predicate_k(tKcK, limit=mK.shape[1]) - if cutlass.const_expr(self.same_hdim_kv): + if const_expr(self.same_hdim_kv): tVpV = tKpK else: tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) # shape: (atom_v_m * rest_m) - softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) softmax.reset() # group parameters for compute_one_n_block mma_params = SimpleNamespace( - thr_mma_qk=thr_mma_qk, thr_mma_pv=thr_mma_pv, - tSrQ=tSrQ, tSrK=tSrK, tOrVt=tOrVt, acc_O=acc_O, + thr_mma_qk=thr_mma_qk, + thr_mma_pv=thr_mma_pv, + tSrQ=tSrQ, + tSrK=tSrK, + tOrVt=tOrVt, + acc_O=acc_O, ) smem_copy_params = SimpleNamespace( smem_thr_copy_Q=smem_thr_copy_Q, smem_thr_copy_K=smem_thr_copy_K, smem_thr_copy_V=smem_thr_copy_V, - tSsQ=tSsQ, tSsK=tSsK, tOsVt=tOsVt, - ) - load_K = partial(self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, - seqlen=seqlen.seqlen_k) - load_V = partial(self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, - seqlen=seqlen.seqlen_k) - # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - # -inf to e.g. -50.0, which can affect the attention softmax. - def scoremod_premask_fn(acc_S): - if cutlass.const_expr(self.has_softcap): - acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) + tSsQ=tSsQ, + tSsK=tSsK, + tOsVt=tOsVt, + ) + load_K = partial( + self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k + ) + load_V = partial( + self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k + ) compute_one_n_block = partial( - self.compute_one_n_block, mma_params=mma_params, smem_copy_params=smem_copy_params, - softmax=softmax, load_K=load_K, load_V=load_V, scoremod_premask_fn=scoremod_premask_fn, + self.compute_one_n_block, + mma_params=mma_params, + smem_copy_params=smem_copy_params, + softmax=softmax, + load_K=load_K, + load_V=load_V, + score_mod=self.score_mod, + batch_idx=batch_size, + head_idx=num_head, + m_block=m_block, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) # /////////////////////////////////////////////////////////////////////////////// @@ -773,7 +920,7 @@ def scoremod_premask_fn(acc_S): def preprocess_Q(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 1) - if cutlass.const_expr(self.Q_in_regs): + if const_expr(self.Q_in_regs): cute.arch.barrier() tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view) @@ -781,22 +928,22 @@ def preprocess_Q(): # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and # read from smem_q to registers, then load V. # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q. - if cutlass.const_expr(self.Q_in_regs): + if const_expr(self.Q_in_regs): load_K(n_block, smem_pipe_write=0, need_predicates=True) cute.arch.cp_async_commit_group() preprocess_Q() cute.arch.barrier() # Make sure all threads have read smem_q before loading V - for stage in range(self.num_stages): - if cutlass.const_expr(not self.Q_in_regs or stage > 0): + for stage in cutlass.range_constexpr(self.num_stages): + if const_expr(not self.Q_in_regs or stage > 0): if stage == 0 or n_block - stage >= 0: - load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() - if stage < self.num_stages - 1: + if const_expr(stage < self.num_stages - 1): if stage == 0 or n_block - stage >= 0: - load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() - if cutlass.const_expr(not self.Q_in_regs): + if const_expr(not self.Q_in_regs): preprocess_Q() # /////////////////////////////////////////////////////////////////////////////// @@ -805,41 +952,63 @@ def preprocess_Q(): # Start processing of the first n-block. # For performance reason, we separate out two kinds of iterations: # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. # We also need masking on S if it's causal, for the last several blocks. mask = AttentionMask( - self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, - self.qhead_per_kvhead if self.pack_gqa else 1, + self.tile_m, + self.tile_n, + seqlen.seqlen_q, + seqlen.seqlen_k, + window_size_left, + window_size_right, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal + mask.apply_mask, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None, ) # First iteration with seqlen masking - smem_pipe_read = cutlass.Int32(0) - smem_pipe_write = cutlass.Int32(self.num_stages - 1) - compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, is_first_n_block=True, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + smem_pipe_read = Int32(0) + smem_pipe_write = Int32(self.num_stages - 1) + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + is_first_n_block=True, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # Next couple of iterations with causal masking - if self.is_causal: + if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) - # Currently we can't do loop with negative step - # https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_min_causal_local_mask, n_block_max - 1, unroll=1): - n_block = n_block_max - 2 - n_tile + n_block_min_causal_local_mask - compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, check_inf=True, - mask_fn=partial(mask_fn, mask_seqlen=False)) + for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 2 - n_tile + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking - for n_tile in cutlass.range_dynamic(n_block, unroll=1): - compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=False) + for n_tile in cutlass.range(n_block, unroll=1): + compute_one_n_block( + n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # TODO: local # normalize acc_O by row_sum and calculate the lse row_scale = softmax.finalize() @@ -851,114 +1020,168 @@ def preprocess_Q(): # reuse sQ's data iterator sO = cute.make_tensor(sQ.iterator, sO_layout) self.epilogue( - acc_O, softmax.row_sum, mO, mLSE, sO, - gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + None, + tiled_mma_pv, + tidx, + m_block, + num_head, + batch_size, ) @cute.jit def compute_one_n_block( self, - n_block: cutlass.Int32, - smem_pipe_read: cutlass.Int32, - smem_pipe_write: cutlass.Int32, + n_block: Int32, + smem_pipe_read: Int32, + smem_pipe_write: Int32, mma_params: SimpleNamespace, smem_copy_params: SimpleNamespace, softmax: Softmax, load_K: Callable, load_V: Callable, - scoremod_premask_fn: Callable, + score_mod: Callable | None, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + aux_tensors=None, + fastdiv_mods=None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, - check_inf: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, ): """Compute one n_block of S/O. This function provides different variants for processing the first n block versus subsequent blocks. """ + def sync(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) cute.arch.barrier() - acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) - acc_S = cute.make_fragment(acc_shape_S, cutlass.Float32) + acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) + acc_S = cute.make_fragment(acc_shape_S, Float32) acc_S.fill(0.0) # wait for smem tile QK before mma calculation for S sync() + # need predicates for the first tile def load_V_next(): if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0: - load_V(n_block - self.num_stages + 1, smem_pipe_write, - need_predicates=is_first_n_block and self.num_stages == 1) + load_V( + n_block - self.num_stages + 1, + smem_pipe_write, + need_predicates=is_first_n_block and self.num_stages == 1, + ) cute.arch.cp_async_commit_group() + load_V_next() sm80_utils.gemm( - mma_params.thr_mma_qk, acc_S, mma_params.tSrQ, mma_params.tSrK, + mma_params.thr_mma_qk, + acc_S, + mma_params.tSrQ, + mma_params.tSrK, smem_copy_params.tSsQ, - smem_copy_params.tSsK[None, None, None, smem_pipe_read if self.num_stages > 1 else 0], - smem_copy_params.smem_thr_copy_Q, smem_copy_params.smem_thr_copy_K, + smem_copy_params.tSsK[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], + smem_copy_params.smem_thr_copy_Q, + smem_copy_params.smem_thr_copy_K, # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, ) - scoremod_premask_fn(acc_S) + if const_expr(score_mod is not None): + self.apply_score_mod( + mma_params.thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale=softmax.softmax_scale, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + def load_K_next(): if n_block - self.num_stages >= 0: load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) cute.arch.cp_async_commit_group() + # wait for smem tile V for O - if cutlass.const_expr(self.num_stages == 1): + if const_expr(self.num_stages == 1): sync() load_K_next() - if cutlass.const_expr(mask_fn is not None): + if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) softmax.rescale_O(mma_params.acc_O, row_scale) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - if cutlass.const_expr(self.num_stages > 1): + if const_expr(self.num_stages > 1): sync() load_K_next() sm80_utils.gemm_rs( - mma_params.thr_mma_pv, mma_params.acc_O, tOrP, mma_params.tOrVt, - smem_copy_params.tOsVt[None, None, None, smem_pipe_read if self.num_stages > 1 else 0], + mma_params.thr_mma_pv, + mma_params.acc_O, + tOrP, + mma_params.tOrVt, + smem_copy_params.tOsVt[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], smem_copy_params.smem_thr_copy_V, # hook_fn=load_K_next, ) - # if cutlass.const_expr(self.num_stages > 1): + # if const_expr(self.num_stages > 1): # load_K_next() class FlashAttentionForwardSm90(FlashAttentionForwardBase): - arch = 90 - def __init__(self, *args, intra_wg_overlap: bool = True, **kwargs): + def __init__( + self, + *args, + intra_wg_overlap: bool = True, + mma_pv_is_rs: bool = True, + **kwargs, + ): super().__init__(*args, **kwargs) self.intra_wg_overlap = intra_wg_overlap + self.mma_pv_is_rs = mma_pv_is_rs def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded - ), - self.dtype + sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), + self.dtype, ) sK_layout_atom = sQ_layout_atom sV_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_v_padded + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv ), - self.dtype + self.dtype, ) sO_layout_atom = sV_layout_atom - sP_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size - ), - self.dtype - ) + if not self.mma_pv_is_rs: + sP_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n + ), + self.dtype, + ) + else: + sP_layout_atom = None return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom def _get_tiled_mma(self): @@ -967,36 +1190,49 @@ def _get_tiled_mma(self): self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, - cutlass.Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.n_block_size), + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_n), ) tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, - cutlass.Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.head_dim_v_padded), + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), + a_source=warpgroup.OperandSource.RMEM + if self.mma_pv_is_rs + else warpgroup.OperandSource.SMEM, ) - return tiled_mma_qk, tiled_mma_pv + tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), + a_source=warpgroup.OperandSource.RMEM, + ) + return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs def _get_shared_storage_cls(self): - # If PackGQA, we use cp.async to load Q, so we want sQ to align to 1024 bytes - sQ_alignment = 128 if not self.pack_gqa else 1024 + # If we use cp.async to load Q, we want sQ to align to 1024 bytes + sQ_alignment = 128 if const_expr(self.use_tma_Q) else 1024 sK_alignment = 128 sV_alignment = 128 sQ_struct, sK_struct, sV_struct = [ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], alignment] for layout, alignment in zip( - (self.sQ_layout, self.sK_layout, self.sV_layout), - (sQ_alignment, sK_alignment, sV_alignment) + (self.sQ_layout, self.sK_layout, self.sV_layout), + (sQ_alignment, sK_alignment, sV_alignment), ) ] - cosize_sQV = utils.max_constexpr(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] - cosize_sP = cute.cosize(self.sP_layout) if self.sP_layout is not None else 0 + cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V, mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2] @@ -1022,121 +1258,260 @@ class SharedStorageSharedQV: sK: sK_struct sP: sP_struct - return SharedStorageQKV if cutlass.const_expr(not self.Q_in_regs) else SharedStorageSharedQV + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV @cute.jit def __call__( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mO: cute.Tensor, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], - mCuSeqlensQ: Optional[cute.Tensor], - mCuSeqlensK: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - mSeqUsedK: Optional[cute.Tensor], - max_seqlen_q: Optional[cutlass.Int32], - softmax_scale: cutlass.Float32, - softcap: cutlass.Float32, + softmax_scale: Float32, stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + aux_tensors: Optional[list] = None, ): """Configures and launches the flash attention kernel. mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ + self._check_type( - *(t.element_type if t is not None else None - for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) ) - QO_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [0, 2, 1] - mQ, mO = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) - for t in (mQ, mO) - ] - KV_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensK is None) else [0, 2, 1] - mK, mV = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) - for t in (mK, mV) + + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) ] - LSE_layout_transpose = [2, 1, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) - tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [utils.select(t, KV_layout_transpose) for t in (mK, mV)] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None + + tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma() self.num_mma_threads = tiled_mma_qk.size self.num_threads_per_warp_group = 128 self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group + self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1) self.num_producer_threads = 32 - self.num_Q_load_threads = self.num_mma_threads # If PackGQA, MMA threads load Q + self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q self.num_epilogue_threads = self.num_mma_threads - self.num_mma_regs = 240 - self.num_producer_regs = 24 - self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if self.intra_wg_overlap else (self.num_mma_warp_groups == 2) - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and not self.pack_gqa + self.num_mma_regs = ( + 256 + if self.num_mma_warp_groups == 1 + else (240 if self.num_mma_warp_groups == 2 else 160) + ) + self.num_producer_regs = ( + 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32) + ) + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + + self.use_scheduler_barrier = ( + (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) + if const_expr(self.intra_wg_overlap) + else (self.num_mma_warp_groups == 2) + ) + self.use_tma_Q = self.arch >= 90 and not ( + self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 + ) + self.use_tma_O = ( + self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + ) # TODO: rescale_O_before_gemm self._setup_attributes() + # TODO: we prob don't need most of what's in _setup_attributes + self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ + sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) + for mX, shape, stage in [ + (mQ, (self.tile_m, self.tile_hdim), None), + (mK, (self.tile_n, self.tile_hdim), self.num_stages), + (mV, (self.tile_n, self.tile_hdimv), self.num_stages), + (mO, (self.tile_m, self.tile_hdimv), None), + ] + ] + self.sP_layout = None + if const_expr(not self.mma_pv_is_rs): + self.sP_layout = sm90_utils.make_smem_layout( + mV.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) + ) + SharedStorage = self._get_shared_storage_cls() + + if const_expr(self.pack_gqa): + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mK.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) + if const_expr(mLSE is not None): + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) + # TMA gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() - self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, self.sQ_layout) - self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) - self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sV_layout, mode=[0, 1])) - tma_atom_Q, tma_tensor_Q = cpasync.make_tma_tile_atom( - gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.m_block_size, self.head_dim_padded), # No mcast - ) - tma_atom_K, tma_tensor_K = cpasync.make_tma_tile_atom( + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ] + } + tma_atom_Q, tma_tensor_Q = None, None + if const_expr(self.use_tma_Q): + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_Q, + mQ, + self.sQ_layout, + (self.tile_m, self.tile_hdim), # No mcast + ) + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mK, cute.select(self.sK_layout, mode=[0, 1]), - (self.n_block_size, self.head_dim_padded), - 1 # No mcast for now + (self.tile_n, self.tile_hdim), + 1, # No mcast for now ) - tma_atom_V, tma_tensor_V = cpasync.make_tma_tile_atom( + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mV, cute.select(self.sV_layout, mode=[0, 1]), - (self.n_block_size, self.head_dim_v_padded), - 1 # No mcast for now - ) - tma_atom_O, tma_tensor_O = cpasync.make_tma_tile_atom( - gmem_tiled_copy_O, mO, self.sO_layout, (self.m_block_size, self.head_dim_v_padded), # No mcast - ) - if cutlass.const_expr(self.pack_gqa): - shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) - stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) - mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) - shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) - stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) - mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) - shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) - stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) - mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(cute.size(mQ.shape[0]) if mCuSeqlensQ is None else max_seqlen_q, self.m_block_size), + (self.tile_n, self.tile_hdimv), + 1, # No mcast for now + ) + tma_atom_O, tma_tensor_O = None, None + if const_expr(self.use_tma_O): + tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_O, + mO, + self.sO_layout, + (self.tile_m, self.tile_hdimv), # No mcast + ) + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_causal or self.is_local) + else SingleTileLPTScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3] if mCuSeqlensQ is None else mCuSeqlensQ.shape[0] - 1), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + 1, # num_splits + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=self.is_causal or self.is_local, ) - # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - # Right after this, we multiply by log2(e) before applying exp2. - # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) - # (assigning it to softmax_scale_log2). + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) LOG2_E = math.log2(math.e) - if cutlass.const_expr(not self.has_softcap): + if const_expr(self.score_mod is None): softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = cutlass.Float32(0.0) + softmax_scale = None else: - softmax_scale_log2 = softcap * LOG2_E - softcap_val = softmax_scale / softcap + # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = LOG2_E + softmax_scale = softmax_scale + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.kernel( - tma_tensor_Q if not self.pack_gqa else mQ, + tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, tma_tensor_K, tma_tensor_V, - mO, - tma_tensor_O, + tma_tensor_O if const_expr(self.use_tma_O) else mO, mLSE, mCuSeqlensQ, mCuSeqlensK, @@ -1147,7 +1522,11 @@ def __call__( tma_atom_V, tma_atom_O, softmax_scale_log2, - softcap_val, + softmax_scale, + window_size_left, + window_size_right, + learnable_sink, + blocksparse_tensors, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -1157,16 +1536,20 @@ def __call__( self.gmem_tiled_copy_K, self.gmem_tiled_copy_V, self.gmem_tiled_copy_O, - # the compiler is unhappy about us using tiled_mma_qk/pv and setting the ACCUMULATE - # field inside a for loop, so we work around by creating multiple copies of the - # tiled_mma_qk/pv. - *((tiled_mma_qk, tiled_mma_pv) * 3), + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, + tile_sched_params, + TileScheduler, SharedStorage, + aux_tensors, + fastdiv_mods, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], smem=SharedStorage.size_in_bytes(), stream=stream, + min_blocks_per_mp=1, ) @cute.kernel @@ -1176,7 +1559,6 @@ def kernel( mK: cute.Tensor, mV: cute.Tensor, mO: cute.Tensor, - mO_tma: cute.Tensor, mLSE: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mCuSeqlensK: Optional[cute.Tensor], @@ -1186,8 +1568,12 @@ def kernel( tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], tma_atom_O: Optional[cute.CopyAtom], - softmax_scale_log2: cutlass.Float32, - softcap_val: cutlass.Float32, + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -1199,52 +1585,53 @@ def kernel( gmem_tiled_copy_O: cute.TiledCopy, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, - tiled_mma_qk_copy: cute.TiledMma, - tiled_mma_pv_copy: cute.TiledMma, - tiled_mma_qk_copy1: cute.TiledMma, - tiled_mma_pv_copy1: cute.TiledMma, - SharedStorage: cutlass.Constexpr, + tiled_mma_pv_rs: cute.TiledMma, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + aux_tensors=Optional[list[cute.Tensor]], + fastdiv_mods=None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Prefetch tma descriptor if warp_idx == 0: - if cutlass.const_expr(not self.pack_gqa): - cpasync.prefetch_descriptor(tma_atom_Q) - cpasync.prefetch_descriptor(tma_atom_K) - cpasync.prefetch_descriptor(tma_atom_V) - if cutlass.const_expr(self.use_tma_O): - cpasync.prefetch_descriptor(tma_atom_O) + for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) # Mbarrier init mbar_ptr_Q = storage.mbar_ptr.data_ptr() - if warp_idx == 0: + if warp_idx == 1: # if tidx < 2: # # barrierO num threads should be self.num_mma_threads - # cute.arch.mbarrier_init_arrive_cnt(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr_Q, 1 if not self.pack_gqa else self.num_Q_load_threads) - # cute.arch.mbarrier_init_arrive_cnt(mbar_ptr_Q + 1, self.num_mma_threads) + # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) + if const_expr(not self.use_tma_Q): + cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) + # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync - pipeline_kv_producer_group = cutlass.utils.CooperativeGroup(cutlass.utils.Agent.Thread) - pipeline_kv_consumer_group = cutlass.utils.CooperativeGroup( - cutlass.utils.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group + pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread + ) + pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) - pipeline_k = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_k = pipeline.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_K.data_ptr(), num_stages=self.num_stages, producer_group=pipeline_kv_producer_group, consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_k_bytes, + tx_count=self.tma_copy_bytes["K"], init_wait=False, ) - pipeline_v = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_v = pipeline.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_V.data_ptr(), num_stages=self.num_stages, producer_group=pipeline_kv_producer_group, consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_v_bytes, + tx_count=self.tma_copy_bytes["V"], ) # /////////////////////////////////////////////////////////////////////////////// @@ -1253,408 +1640,800 @@ def kernel( # TODO: how to get sQ_pi for cp.async if pack_gqa? sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - if cutlass.const_expr(not self.Q_in_regs): + if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) else: - sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) - if cutlass.const_expr(sP_layout is not None): - # sP_pi = storage.sP.get_tensor(sP_layout) - sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) - sP_pi = cute.make_tensor(sP.iterator, sP_layout) - else: - sP, sP_pi = None - # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma + sV = storage.sQ.get_tensor( + sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type + ) + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = utils.transpose_view(sV) + sP = None + if const_expr(sP_layout is not None): + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + # reuse sQ's data iterator + sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) - # Thread index, block index - tidx, _, _ = cute.arch.thread_idx() - m_block, head_idx, batch_idx = cute.arch.block_idx() block_info = BlockInfo( - self.m_block_size, self.n_block_size, self.is_causal, - self.qhead_per_kvhead if self.pack_gqa else 1, - ) - seqlen = SeqlenInfo( - batch_idx, mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], mK.shape[0], mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK - ) - # Can't early exit so we have to write it this way (under an if statement) - if mCuSeqlensQ is None or m_block * self.n_block_size < seqlen.seqlen_q: - if cutlass.const_expr(self.is_causal): # Longest tile first - m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 - if cutlass.const_expr(mCuSeqlensQ is None): - mQ_cur = mQ[None, None, head_idx, batch_idx] - else: - mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # TODO: return early if n_block_max == 0 - # if self.is_causal: - # if n_block_max <= 0: - # return - - if warp_idx < 4: # Producer - cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - head_idx_kv = head_idx // self.qhead_per_kvhead if not self.pack_gqa else head_idx - if cutlass.const_expr(mCuSeqlensK is None): - mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] - else: - mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] - gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) - gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) - if cutlass.const_expr(not self.pack_gqa): - gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) - tQsQ, tQgQ = cpasync.tma_partition( - tma_atom_Q, - 0, - cute.make_layout(1), - cute.group_modes(sQ, 0, 2), - cute.group_modes(gQ, 0, 2), - ) - tKsK, tKgK = cpasync.tma_partition( - tma_atom_K, - 0, - cute.make_layout(1), - cute.group_modes(sK, 0, 2), - cute.group_modes(gK, 0, 2), - ) - tVsV, tVgV = cpasync.tma_partition( - tma_atom_V, - 0, - cute.make_layout(1), - cute.group_modes(sV, 0, 2), - cute.group_modes(gV, 0, 2), - ) - smem_pipe_write = pipeline.make_pipeline_state( - cutlass.utils.PipelineUserType.Producer, self.num_stages + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: # Producer + cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + blocksparse_tensors, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + else: # Consumer + cute.arch.warpgroup_reg_alloc(self.num_mma_regs) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, + mQ, + mO, + mLSE, + sQ, + sK, + sVt, + sP, + sO, + learnable_sink, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + gmem_tiled_copy_Q, + gmem_tiled_copy_O, + tma_atom_O, + tidx, + softmax_scale_log2, + softmax_scale, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + blocksparse_tensors, + aux_tensors, + fastdiv_mods, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + if warp_idx_in_wg == 0: + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.num_stages + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx ) - load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) - load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) - if warp_idx == 0: # Producer - # load_Q - if cutlass.const_expr(not self.pack_gqa): - with cute.arch.elect_one(): - cute.arch.mbarrier_init_tx_bytes(mbar_ptr_Q, self.tma_copy_q_bytes) - cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) - for n_tile in cutlass.range_dynamic(n_block_max, unroll=2): - n_block = n_block_max - n_tile - 1 - load_K(n_block, smem_pipe_write=smem_pipe_write) - load_V(n_block, smem_pipe_write=smem_pipe_write) - smem_pipe_write.advance() - - else: # Consumer - cute.arch.warpgroup_reg_alloc(self.num_mma_regs) - # /////////////////////////////////////////////////////////////////////////////// - # Tile MMA compute thread partitions and allocate accumulators - # /////////////////////////////////////////////////////////////////////////////// - tidx = tidx - 128 - warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) - warp_group_thread_layout = cute.make_layout( - self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + if const_expr(self.use_tma_Q): + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True + ) + # TODO: mcast + # TODO check warp_idx if we have 128 producer threads + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK ) - thr_mma_qk = tiled_mma_qk.get_slice(tidx) - wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) - tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) - tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) - tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) if cutlass.const_expr(sP is not None) else None - tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) - acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) - acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) - - # /////////////////////////////////////////////////////////////////////////////// - # Smem copy atom tiling - # /////////////////////////////////////////////////////////////////////////////// - smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_P = utils.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) - # tPsP = smem_thr_copy_P.partition_D(sP_pi) if cutlass.const_expr(sP_pi is not None) else None - tPsP = smem_thr_copy_P.partition_D(sP) if cutlass.const_expr(sP is not None) else None - # if cute.arch.thread_idx()[0] == 0: - # cute.printf(sP_pi.layout, sP_pi.iterator) - # cute.printf(sP.layout, sP.iterator) - # cute.printf(tPsP.layout, tPsP.iterator) - - self.mma_init() - - # shape: (atom_v_m * rest_m) - softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) - softmax.reset() - # group parameters for compute_one_n_block - mma_params = SimpleNamespace(tSrQ=tSrQ, tSrK=tSrK, tOrP=tOrP, tOrVt=tOrVt, acc_O=acc_O) - smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) - - # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - # -inf to e.g. -50.0, which can affect the attention softmax. - def scoremod_premask_fn(acc_S): - if cutlass.const_expr(self.has_softcap): - acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) - - mask = AttentionMask( - self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, - self.qhead_per_kvhead if self.pack_gqa else 1 + load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV ) - mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal + load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) + + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + # First iteration: load both Q & K with the same mbarrier + n_block = n_block_max - 1 + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block, producer_state=kv_producer_state) + + if const_expr(not self.intra_wg_overlap): + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + n_block = n_block_min + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + kv_producer_state = produce_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + self.use_tma_Q, + self.tma_copy_bytes["Q"], + self.intra_wg_overlap, + ) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, + # softmax: Softmax, + # acc_O: cute.Tensor, + mQ: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sQ: cute.Tensor, + sK: cute.Tensor, + sVt: cute.Tensor, + sP: Optional[cute.Tensor], + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tidx: Int32, + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + aux_tensors: Optional[list], + fastdiv_mods=None, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) + tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) + tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) + if const_expr(self.mma_pv_is_rs): + acc_S_shape = tiled_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) + tOrP = cute.make_fragment( + utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype + ) + else: + tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) + tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + # tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None + tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None + # if cute.arch.thread_idx()[0] == 0: + # cute.printf(sP_pi.layout, sP_pi.iterator) + # cute.printf(sP.layout, sP.iterator) + # cute.printf(tPsP.layout, tPsP.iterator) + + self.mma_init() + + acc_shape_O = tiled_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) + acc_O = cute.make_fragment(acc_shape_O, Float32) + smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + + mma_qk_fn = partial( + sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK + ) + mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) + + mma_one_n_block_all = partial( + self.mma_one_n_block_intrawg_overlap + if const_expr(self.intra_wg_overlap) + else self.mma_one_n_block, + mma_qk_fn=mma_qk_fn, + tiled_mma_pv_rs=tiled_mma_pv_rs, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + acc_O=acc_O, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + check_inf=True, + ) + + q_consumer_phase = Int32(0) + kv_consumer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.num_stages + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + + process_first_half_block = partial( + self.first_half_block_overlap, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + softmax=softmax, + ) + process_last_half_block = partial( + self.last_half_block_overlap, + pipeline_v=pipeline_v, + mma_pv_fn=mma_pv_fn, + ) + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + + # shape: (atom_v_m * rest_m) + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + score_mod_fn = None + if const_expr(self.score_mod is not None): + score_mod_fn = partial( + self.apply_score_mod, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + softmax_scale=softmax_scale, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) - compute_one_n_block = partial( - self.compute_one_n_block_intrawg_overlap if cutlass.const_expr(self.intra_wg_overlap) else self.compute_one_n_block, - pipeline_k=pipeline_k, pipeline_v=pipeline_v, - mma_params=mma_params, smem_copy_params=smem_copy_params, - softmax=softmax, scoremod_premask_fn=scoremod_premask_fn, + mma_one_n_block = partial( + mma_one_n_block_all, + softmax=softmax, + score_mod_fn=score_mod_fn, + ) + # Load Q if not TMA_Q + if const_expr(not self.use_tma_Q): + pack_gqa = PackGQA( + self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead ) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, + # headdim=mQ.shape[1]) + pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) + cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q) - # Load Q if PackGQA - if cutlass.const_expr(self.pack_gqa): - pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) - # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) - # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) - # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, - # headdim=mQ.shape[1]) - pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) - utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) - - n_block = n_block_max - 1 - smem_pipe_read = pipeline.make_pipeline_state( - cutlass.utils.PipelineUserType.Consumer, self.num_stages - ) - cute.arch.mbarrier_wait(mbar_ptr_Q, phase=0) - # For performance reason, we separate out two kinds of iterations: - # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of n_block_size. - # We also need masking on S if it's causal, for the last several blocks. + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + if const_expr(not self.use_tma_Q): + cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) + q_consumer_phase ^= 1 + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. + # We also need masking on S if it's causal, for the last several blocks. + # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True + O_should_accumulate = False + + # ========================================== + # MAINLOOP + # ========================================== + if const_expr(not self.use_block_sparsity): + # ========================================== + # No block-sparsity (original path) + # ========================================== # First iteration with seqlen masking - if cutlass.const_expr(self.intra_wg_overlap): - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 - ) - pipeline_k.consumer_wait(smem_pipe_read) - sm90_utils.gemm( - tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, smem_pipe_read.index], - zero_init=True, wg_wait=0 + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_first_half_block( + n_block=n_block_max - 1, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + score_mod_fn=score_mod_fn, + is_first_block=True, ) - pipeline_k.consumer_release(smem_pipe_read) - scoremod_premask_fn(acc_S) - mask_fn(acc_S, n_block=n_block, mask_seqlen=True) - softmax.online_softmax(acc_S, is_first=True, check_inf=True) - rP = cute.make_fragment_like(acc_S, self.dtype) - rP.store(acc_S.load().to(self.dtype)) - # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - tPrP = smem_thr_copy_P.retile(rP) - cute.copy(smem_thr_copy_P, tPrP, tPsP) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - acc_O.fill(0.0) + # acc_O.fill(0.0) else: self.warp_scheduler_barrier_sync() - compute_one_n_block( - n_block, smem_pipe_read, tiled_mma_qk, tiled_mma_pv, - is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True) + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) - smem_pipe_read.advance() + O_should_accumulate = True + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 # Next couple of iterations with causal masking - if cutlass.const_expr(self.is_causal): + if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) - # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_min_causal_local_mask, n_block_max - 1, unroll=1): - n_block = n_block_max - 2 - n_tile + n_block_min_causal_local_mask - compute_one_n_block( - n_block, smem_pipe_read, tiled_mma_qk_copy, tiled_mma_pv_copy, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) - smem_pipe_read.advance() + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking - for n_tile in cutlass.range_dynamic(n_block, unroll=1): - compute_one_n_block( - n_block - n_tile - 1, smem_pipe_read, tiled_mma_qk_copy1, tiled_mma_pv_copy1, - check_inf=False, + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) - smem_pipe_read.advance() + O_should_accumulate = True + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True # Last "half" iteration - if cutlass.const_expr(self.intra_wg_overlap): - pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, smem_pipe_read.index], - zero_init=False, wg_wait=-1 + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, ) - warpgroup.wait_group(0) - pipeline_v.consumer_release(smem_pipe_read) - smem_pipe_read.advance() + O_should_accumulate = True else: self.warp_scheduler_barrier_arrive() - # normalize acc_O by row_sum and calculate the lse - row_scale = softmax.finalize() - softmax.rescale_O(acc_O, row_scale) - - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue - # /////////////////////////////////////////////////////////////////////////////// - # reuse sQ's data iterator - sO_pi = cute.make_tensor(sQ.iterator, sO_layout) - # TODO: idk why using not using sO_pi is faster - sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) - self.epilogue( - acc_O, softmax.row_sum, mO if not self.use_tma_O else mO_tma, mLSE, sO, seqlen, - gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, - is_varlen=cutlass.const_expr(mCuSeqlensQ is not None), + else: + # ========================================== + # Block sparsity + # ========================================== + kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + self.mask_mod, + fastdiv_mods, + self.intra_wg_overlap, + self.warp_scheduler_barrier_sync, + self.warp_scheduler_barrier_arrive, ) + # Handle empty case (when no blocks to process) + if not processed_any: + softmax.reset() + acc_O.fill(0.0) + + sink_val = None + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + else: # Each thread might have a different sink value due to different q_head + sink_val = cute.make_fragment_like(softmax.row_max, Float32) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS)) + for r in cutlass.range(cute.size(sink_val), unroll_full=True): + row = m_block * self.tile_m + tScS_mn[r][0] + q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + sink_val[r] = Float32(learnable_sink[q_head_idx]) + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize(sink_val=sink_val) + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + self.epilogue( + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + tma_atom_O, + tiled_mma_pv, + tidx, + m_block, + head_idx, + batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit - def compute_one_n_block( + def first_half_block_overlap( self, - n_block: cutlass.Int32, - smem_pipe_read: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - pipeline_k: cutlass.utils.PipelineAsync, - pipeline_v: cutlass.utils.PipelineAsync, - mma_params: SimpleNamespace, + n_block: Int32, + mma_qk_fn: Callable, + kv_consumer_state, + pipeline_k, + tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, - scoremod_premask_fn: Callable, + mask_fn: Callable = None, + score_mod_fn: Optional[Callable] = None, + is_first_block: bool = False, + ): + """Processes the first half block when using intra-warpgroup-overlap""" + + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) + pipeline_k.consumer_release(kv_consumer_state) + + # Apply score modification if present + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block) + + # Apply mask; mask_seqlen always True for first block + # Caveat: if full block further right than mask block, seqlen masking is redundant; + # however, masking is being applied anyway, so essentially no perf hit + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + + softmax.online_softmax(acc_S, is_first=is_first_block) + + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + + # if pv gemm not rs + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + # Fence and barrier to make smem store visible to WGMMA + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + + return kv_consumer_state + + @cute.jit + def last_half_block_overlap( + self, + kv_consumer_state, + pipeline_v, + mma_pv_fn: Callable, + zero_init: bool, + ): + """Processes the final PV GEMM when using intra-warpgroup-overlap""" + + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) + pipeline_v.consumer_release(kv_consumer_state) + + # Advance state for next iteration + kv_consumer_state.advance() + + return kv_consumer_state + + @cute.jit + def mma_one_n_block( + self, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + tiled_mma_pv_rs: cute.TiledMma, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, - check_inf: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, ): - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 - ) pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) - sm90_utils.gemm( - tiled_mma_qk, acc_S, mma_params.tSrQ, - mma_params.tSrK[None, None, None, smem_pipe_read.index], - zero_init=True, wg_wait=-1 - ) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) - scoremod_premask_fn(acc_S) - if cutlass.const_expr(mask_fn is not None): - mask_fn(acc_S, n_block=n_block) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) - # if cute.arch.thread_idx()[0] == 0: - # cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - rP = cute.make_fragment_like(acc_S, self.dtype) - rP.store(acc_S.load().to(self.dtype)) - # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - tPrP = smem_copy_params.smem_thr_copy_P.retile(rP) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(mma_params.acc_O, row_scale) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, smem_pipe_read.index], - zero_init=is_first_n_block, wg_wait=0 - ) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) pipeline_v.consumer_release(smem_pipe_read) + smem_pipe_read.advance() + return smem_pipe_read @cute.jit - def compute_one_n_block_intrawg_overlap( + def mma_one_n_block_intrawg_overlap( self, - n_block: cutlass.Int32, - smem_pipe_read: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - pipeline_k: cutlass.utils.PipelineAsync, - pipeline_v: cutlass.utils.PipelineAsync, - mma_params: SimpleNamespace, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + tiled_mma_pv_rs: cute.TiledMma, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, - scoremod_premask_fn: Callable, + score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, - check_inf: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, ): - smem_pipe_read_k = smem_pipe_read.clone() - smem_pipe_read_k.advance() - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 - ) - pipeline_k.consumer_wait(smem_pipe_read_k, pipeline_k.consumer_try_wait(smem_pipe_read_k)) + smem_pipe_read_v = smem_pipe_read.clone() + smem_pipe_read.advance() + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() - sm90_utils.gemm( - tiled_mma_qk, acc_S, mma_params.tSrQ, - mma_params.tSrK[None, None, None, smem_pipe_read_k.index], - zero_init=True, wg_wait=-1 - ) - pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, smem_pipe_read.index], - zero_init=False, wg_wait=-1 - ) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) - pipeline_k.consumer_release(smem_pipe_read_k) - scoremod_premask_fn(acc_S) - if cutlass.const_expr(mask_fn is not None): - mask_fn(acc_S, n_block=n_block) - # if cute.arch.thread_idx()[0] == 128: - # cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) - pipeline_v.consumer_release(smem_pipe_read) - rP = cute.make_fragment_like(acc_S, self.dtype) - rP.store(acc_S.load().to(self.dtype)) - # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - tPrP = smem_copy_params.smem_thr_copy_P.retile(rP) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(mma_params.acc_O, row_scale) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + pipeline_v.consumer_release(smem_pipe_read_v) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + return smem_pipe_read @cute.jit def mma_init(self): warp_group_idx = utils.canonical_warp_group_idx(sync=False) - if cutlass.const_expr(self.use_scheduler_barrier): + if const_expr(self.use_scheduler_barrier): if warp_group_idx == 1: - utils.barrier_arrive( + cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * self.num_threads_per_warp_group, ) + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + def warp_scheduler_barrier_sync(self): - if cutlass.const_expr(self.use_scheduler_barrier): + if const_expr(self.use_scheduler_barrier): cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - 1 + utils.canonical_warp_group_idx(sync=False), - number_of_threads=2 * self.num_threads_per_warp_group + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + - 1 + + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group, ) def warp_scheduler_barrier_arrive(self): - if cutlass.const_expr(self.use_scheduler_barrier): + if const_expr(self.use_scheduler_barrier): assert self.num_mma_warp_groups in [2, 3] cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 - next_wg = 1 - cur_wg if self.num_mma_warp_groups == 2 else (cur_wg + 1 if cur_wg < self.num_mma_warp_groups - 1 else 0) - utils.barrier_arrive( + if const_expr(self.num_mma_warp_groups == 2): + next_wg = 1 - cur_wg + else: + t = cur_wg + 1 + next_wg = t % self.num_mma_warp_groups + cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, - ) - - # @cute.jit - def load_K( - self, - tma_atom: cute.CopyAtom, - tKgK: cute.Tensor, - tKsK: cute.Tensor, - pipeline: cutlass.utils.PipelineAsync, - block: cutlass.Int32, - smem_pipe_write: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, - ): - # TODO: mcast - # TODO check warp_idx if we have 128 producer threads - pipeline.producer_acquire(smem_pipe_write) - cute.copy( - tma_atom, - tKgK[None, block], - tKsK[None, smem_pipe_write.index], - tma_bar_ptr=pipeline.producer_get_barrier(smem_pipe_write) - ) + ) \ No newline at end of file diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py new file mode 100644 index 00000000000..f97e127175d --- /dev/null +++ b/flash_attn/cute/flash_fwd_combine.py @@ -0,0 +1,704 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h +# from Cutlass C++ to Cute-DSL. +import math +import operator +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, const_expr + +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor + + +class FlashAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + m_block_size: int = 8, + k_block_size: int = 64, + log_max_splits: int = 4, + num_threads: int = 256, + stages: int = 4, + ): + """ + Forward combine kernel for split attention computation. + + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param m_block_size: m block size + :param k_block_size: k block size + :param log_max_splits: log2 of maximum splits + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.m_block_size = m_block_size + self.k_block_size = k_block_size + self.max_splits = 1 << log_max_splits + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + + @staticmethod + def can_implement( + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if m_block_size % 8 != 0: + return False + max_splits = 1 << log_max_splits + if max_splits > 256: + return False + if (m_block_size * max_splits) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) # 4 vals per load + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, + tOpartial_layout, + vOpartial_layout, # 4 vals per store + ) + + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 + if self.m_block_size % 128 == 0 + else ( + 64 + if self.m_block_size % 64 == 0 + else ( + 32 + if self.m_block_size % 32 == 0 + else (16 if self.m_block_size % 16 == 0 else 8) + ) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) + ) + + # O partial shared memory layout (simple layout for pipeline stages) + self.smem_layout_o = cute.make_ordered_layout( + (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(mLSE_partial.element_type not in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mO_partial, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mO_partial, mO) + ] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b) + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) + + # Create FastDivmodDivisor objects for efficient division + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.m_block_size), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + semaphore_to_reset, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, batch_idx = cute.arch.block_idx() + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,)) + sO = storage.sO.get_tensor(smem_layout_o) + + # Handle semaphore reset + if const_expr(semaphore_to_reset is not None): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and batch_idx == cute.arch.grid_dim()[2] - 1 + ): + semaphore_to_reset[0] = 0 + + # Get number of splits + num_splits = ( + num_splits_dynamic_ptr[batch_idx] + if const_expr(num_splits_dynamic_ptr is not None) + else mLSE_partial.shape[1] + ) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo.create( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused, + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + # Extract number of heads (head index will be determined dynamically) + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + + # Early exit for single split if dynamic + if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( + const_expr(not varlen) or m_block * self.m_block_size < max_idx + ): + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + + if const_expr(cu_seqlens is None): + # mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx] + mLSE_partial_cur = utils.coord_offset_i64(mLSE_partial, batch_idx, dim=3) + else: + # mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial) + mLSE_partial_cur = utils.domain_offset_i64((offset, 0, 0), mLSE_partial) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + + # Create identity tensor for coordinate tracking + cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + # Load LSE partial values + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] # Get m coordinate + idx = m_block * self.m_block_size + mi + if idx < max_idx: + # Calculate actual sequence position and head using FastDivmodDivisor + if const_expr(not varlen): + head_idx, m_idx = divmod(idx, seqlen_divmod) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] # Get split coordinate + if si < num_splits: + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # =============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + if const_expr(cu_seqlens is None): + # mO_partial_cur = mO_partial[None, None, None, None, batch_idx] + mO_partial_cur = utils.coord_offset_i64(mO_partial, batch_idx, dim=4) + else: + # mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial) + mO_partial_cur = utils.domain_offset_i64((offset, 0, 0, 0), mO_partial) + + # Precompute these values to avoid recomputing them in the loop + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_fragment(num_rows, cutlass.Int32) + tOhidx = cute.make_fragment(num_rows, cutlass.Int32) + tOrOptr = cute.make_fragment(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate + idx = m_block * self.m_block_size + mi + if const_expr(not varlen): + tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod) + else: + tOhidx[m] = idx // seqlen + tOmidx[m] = idx - tOhidx[m] * seqlen + tOrOptr[m] = utils.elem_pointer_i64( + mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]) + ).toint() + if idx >= max_idx: + tOhidx[m] = -1 + + tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean) + if const_expr(not self.is_even_k): + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_fragment_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = utils.warp_reduce( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + op=cute.arch.fmax, + width=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col) + # Compute exp scales and sum + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E)) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col) + lse_sum[m] = utils.logf(lse_sum_cur) + lse_max + # Normalize scales + inv_sum = ( + 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + ) + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.m_block_size: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + # mLSE_cur = mLSE[None, None, batch_idx] + mLSE_cur = utils.coord_offset_i64(mLSE, batch_idx, dim=2) + else: + # mLSE_cur = cute.domain_offset((offset, 0), mLSE) + mLSE_cur = utils.domain_offset_i64((offset, 0), mLSE) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.m_block_size + mi + if idx < max_idx: + if const_expr(not varlen): + head_idx, m_idx = divmod(idx, seqlen_divmod) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_cur[m_idx, head_idx] = lse_sum[m] + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1])): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_fragment_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_fragment(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + # =============================== + # Step 7: Write final O to gmem + # =============================== + + rO = cute.make_fragment_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + if const_expr(cu_seqlens is None): + # mO_cur = mO[None, None, None, batch_idx] + mO_cur = utils.coord_offset_i64(mO, batch_idx, dim=3) + else: + # mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_i64((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,)) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + # Write final results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + mO_cur_copy = cute.tiled_divide( + mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,) + ) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOpO: cute.Tensor, + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy( + gmem_tiled_copy_O_partial, + # mO_partial_cur_copy[None, k_idx, split], + utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx], + tOsO_partial_cur[None, m, k], + ) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py new file mode 100644 index 00000000000..645ad97b003 --- /dev/null +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -0,0 +1,2678 @@ +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128, (192, 128). +# - varlen +# - sliding window +# - split-kv +# Unsupported features that will be added later: +# - page size != 128 +# - more hdim (192, 256) +# Based on the cutlass example and cute-dsl example: +# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py + +import enum +import math +from typing import Type, Tuple, Callable, Optional, Literal +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils_basic + +from flash_attn.cute.paged_kv import PagedKVManager +import flash_attn.cute.utils as utils +from flash_attn.cute import copy_utils +import flash_attn.cute.pipeline as pipeline +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_block_count, + produce_block_sparse_loads_sm100, + softmax_block_sparse_sm100, + handle_block_sparse_empty_tile_correction_sm100, +) +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils +from cutlass.cute import FastDivmodDivisor +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + StaticPersistentTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) + + +class NamedBarrierFwd(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() +# WarpSchedulerWG1 = enum.auto() +# WarpSchedulerWG2 = enum.auto() +# WarpSchedulerWG3 = enum.auto() +# PFull = enum.auto() +# PEmpty = enum.auto() + + +class FlashAttentionForwardSm100: + arch = 100 + + def __init__( + self, + # dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + is_causal: bool = False, + is_local: bool = False, + is_split_kv: bool = False, + pack_gqa: bool = False, + m_block_size: int = 128, + n_block_size: int = 128, + is_persistent: bool = True, + score_mod: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, + paged_kv_non_tma: bool = False, + is_varlen_q: bool = False, + ): + self.use_tma_KV = not paged_kv_non_tma + # self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.q_stage = 2 + assert self.q_stage in [1, 2] + + # 2 Q tile per CTA + self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.is_causal = is_causal + self.is_local = is_local + self.is_varlen_q = is_varlen_q + self.use_correction_warps_for_epi = is_varlen_q + self.qhead_per_kvhead = qhead_per_kvhead + self.is_split_kv = is_split_kv + self.pack_gqa = pack_gqa + if pack_gqa: + assert m_block_size % self.qhead_per_kvhead == 0, ( + "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + ) + assert not (self.is_split_kv and self.head_dim_v_padded >= 192), ( + "SplitKV is not supported for hdim >= 192" + ) + self.score_mod = score_mod + self.mask_mod = mask_mod + if cutlass.const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 2 + # Does S1 need to wait for S0 to finish + # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) + self.s0_s1_barrier = False + self.overlap_sO_sQ = ( + (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or + (self.head_dim_v_padded >= 128 and self.is_split_kv) + ) + if self.overlap_sO_sQ: + self.is_persistent = False + + assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( + "Paged KV does not support irregular head dim" + ) + + self.softmax0_warp_ids = (0, 1, 2, 3) + self.softmax1_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + self.mma_warp_id = 12 + self.epilogue_warp_ids = (13,) + self.load_warp_ids = (14,) + self.empty_warp_ids = (15,) + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + *self.load_warp_ids, + *self.epilogue_warp_ids, + *self.empty_warp_ids, + ) + ) + + if not self.use_tma_KV: + self.load_warp_ids = (14, 15) + self.empty_warp_ids = () + if self.use_correction_warps_for_epi: + self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids + self.epilogue_warp_ids = self.correction_warp_ids + elif self.is_varlen_q: # fallback + self.epilogue_warp_ids = (13, 14) + + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 + self.tmem_o_offset = [ + self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded + for i in range(self.q_stage) + ] # e.g., 256, 384 + self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded + assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS + self.tmem_s_to_p_offset = self.n_block_size // 2 + self.tmem_p_offset = [ + self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) + ] # 0, 128 + + # vec buffer for row_max & row_sum + self.tmem_vec_offset = self.tmem_s_offset + + if self.head_dim_padded < 96: + self.num_regs_softmax = 200 + self.num_regs_correction = 64 + self.num_regs_other = 48 + else: + # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 + self.num_regs_softmax = 200 + # self.num_regs_softmax = 176 + # self.num_regs_correction = 96 + # self.num_regs_correction = 80 + # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 + self.num_regs_correction = 64 + # self.num_regs_other = 32 + # self.num_regs_other = 64 + # self.num_regs_other = 80 + self.num_regs_other = 48 + # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 + # self.num_regs_other = 64 if self.is_causal or self.is_local else 80 + self.num_regs_empty = 24 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations and parameters for the FMHA kernel operation. + + This method initializes and configures various attributes required for the + execution of the fused multi-head attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.kv_stage = 4 if self.q_dtype.width == 8 else 3 + self.acc_stage = 1 + self.epi_stage = 2 + # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: + # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. + # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is + # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be + # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, + # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64. + self.uneven_kv_smem = ( + self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 + ) + self.uneven_kv_smem_offset = ( + self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 + if self.uneven_kv_smem + else 0 + ) + assert self.uneven_kv_smem_offset % 1024 == 0 + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + aux_tensors: Optional[list] = None, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + This method prepares the input tensors for processing, validates their shapes and types, + configures the computation parameters, and launches the CUDA kernel. + + The method handles: + 1. Tensor layout transformations for specific memory access patterns + 2. Validation of tensor shapes and data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch with appropriate parameters + """ + # setup static attributes before smem/grid/tma computation + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = mO.element_type + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] + Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose)) + # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) + for t in (mK, mV) + ] + if const_expr(self.is_split_kv): + O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0] + LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0] + num_splits = mO.shape[0] + else: + O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + num_splits = Int32(1) + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if const_expr(mLSE is not None) + else None + ) + # (s, d, h, b) -> (d, s, h, b) + V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) + + self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mQ is not supported") + if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mK is not supported") + if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of mV is not supported") + + # check type consistency + if const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None + # This can be tuned + self.e2e_freq = 16 + if const_expr( + self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa + ): + self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 + + cta_group = tcgen05.CtaGroup.ONE + # the intermediate tensor p is from tmem & mK-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.mma_tiler_qk[:2], + ) + tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.mma_tiler_pv[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_qk.thr_id.shape,), + ) + + self.epi_tile = self.mma_tiler_pv[:2] + + sQ_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_qk, + self.mma_tiler_qk, + self.q_dtype, + self.q_stage, + ) + sK_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_qk, + self.mma_tiler_qk, + self.k_dtype, + self.kv_stage, + ) + tP_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_pv, + self.mma_tiler_pv, + self.q_dtype, + self.acc_stage, + ) + sV_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_pv, + self.mma_tiler_pv, + self.v_dtype, + self.kv_stage, + ) + sO_layout = sm100_utils_basic.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.epi_stage, + ) + if const_expr(not self.same_hdim_kv_padded): + # sK and sV are using the same physical smem so we need to adjust the stride so that they line up + stride_sK = const_expr( + max(sK_layout.outer.stride[-1], 0) + ) # take max to turn tuple to Int32 + stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0)) + stage_stride = const_expr( + max(stride_sK, stride_sV) + if not self.uneven_kv_smem + else (stride_sK + stride_sV) // 2 + ) + sK_layout = cute.make_composed_layout( + sK_layout.inner, + 0, + cute.make_layout( + (*sK_layout.outer.shape[:-1], self.kv_stage), + stride=(*sK_layout.outer.stride[:-1], stage_stride), + ), + ) + sV_layout = cute.make_composed_layout( + sV_layout.inner, + 0, + cute.make_layout( + (*sV_layout.outer.shape[:-1], self.kv_stage), + stride=(*sV_layout.outer.stride[:-1], stage_stride), + ), + ) + + if const_expr(self.pack_gqa): + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mO.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) + if const_expr(mLSE is not None): + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", mQ, sQ_layout), + ("K", mK, sK_layout), + ("V", mV, sV_layout), + ] + } + + # TMA load for Q + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + + if const_expr(self.use_tma_KV): + # TMA load for K + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mK, + cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mV, + cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, + tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + else: + tma_atom_K = None + tma_atom_V = None + + o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) + + self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + if const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tiled_tma_atom( + tma_store_op, + mO, + cute.select(sO_layout, mode=[0, 1]), + o_cta_v_layout, + ) + gmem_tiled_copy_O = None + else: + tma_atom_O = None + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.o_dtype.width + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.o_dtype, + num_bits_per_copy=universal_copy_bits, + ) + tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), + order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.m_block_size % tO_layout.shape[0] == 0 + vO_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + if const_expr(self.is_causal or self.is_local): + TileScheduler = SingleTileLPTScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_persistent) + else StaticPersistentTileScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + num_splits, + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mQ.shape[1], + mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100 + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=self.cta_tiler[:2], + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, + lpt=self.is_causal or self.is_local, + is_split_kv=self.is_split_kv, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + self.mbar_load_q_full_offset = 0 + self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage + self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage + self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage + self.mbar_P_full_O_rescaled_offset = self.mbar_load_kv_empty_offset + self.kv_stage + self.mbar_S_full_offset = self.mbar_P_full_O_rescaled_offset + 2 + self.mbar_O_full_offset = self.mbar_S_full_offset + 2 + self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + 2 + self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + 2 + self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.epi_stage + self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.epi_stage + self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + 2 + self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8 + self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 + self.mbar_total = self.mbar_P_full_2_offset + 2 + + sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0 + sQ_size = ( + cute.cosize(sQ_layout) if const_expr(not self.overlap_sO_sQ) else + cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) + ) + + @cute.struct + class SharedStorage: + # m_barriers for pipelines + mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total] + # Tmem holding buffer + tmem_holding_buf: Int32 + # Smem tensors + # store row max and row sum + sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, sO_size], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, sQ_size], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softmax_scale = None + else: + # NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = LOG2_E + softmax_scale = softmax_scale + + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if cutlass.const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None): + raise NotImplementedError("Block sparsity + paged KV not supported on SM100") + + # Launch the kernel synchronously + self.kernel( + mQ, + mK, + mV, + mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + learnable_sink, + blocksparse_tensors, + sQ_layout, + sK_layout, + tP_layout, + sV_layout, + sO_layout, + gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tile_sched_params, + num_splits, + aux_tensors, + fastdiv_mods, + ).launch( + grid=grid_dim, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q + mK: cute.Tensor, # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table + mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softmax_scale: Float32 | None, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + gmem_tiled_copy_O: Optional[cute.TiledCopy], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: ParamsBase, + num_splits: Int32, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + ): + """The device kernel implementation of the Fused Multi-Head Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Softmax warps: Compute softmax normalization on attention scores + 4. Correction warps: Apply adjustments to intermediate results + 5. Epilogue warp: Handles final output transformation and storage + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking. + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + if const_expr(tma_atom_K is not None): + cpasync.prefetch_descriptor(tma_atom_K) + if const_expr(tma_atom_V is not None): + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(tma_atom_O is not None): + cpasync.prefetch_descriptor(tma_atom_O) + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + mbar_ptr = storage.mbar_ptr.data_ptr() + # Use the first N warps to initialize barriers + if warp_idx == 1: + # Init "full" barrier with number of producers, "empty" barrier with number of consumers + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_full_offset + i, 1 + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id]) + ) + if warp_idx == 2: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4 + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4 + ) + if warp_idx == 3: + if const_expr(self.s0_s1_barrier): + for i in cutlass.range_constexpr(8): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE + ) + if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4: + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_full_offset + i, + cute.arch.WARP_SIZE * len(self.correction_warp_ids), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_empty_offset + i, + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids), + ) + if warp_idx == 5: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, + cute.arch.WARP_SIZE + * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id]) + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id]) + ) + if warp_idx == 6: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_2_offset + i, + cute.arch.WARP_SIZE * len(self.softmax0_warp_ids), + ) + if warp_idx == 7: + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_tmem_dealloc_offset, + cute.arch.WARP_SIZE + * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + ) + ), + ) + # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync + pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset) + + # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + # (MMA, MMA_K, MMA_D, PIPE) + # Strip swizzle info to reuse smem + sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) + if const_expr(not self.overlap_sO_sQ): + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + else: + sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer) + + sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2)) + + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) + tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) + + pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) + + tStSs = tuple( + cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) + for stage in range(2) + ) + tOtOs = tuple( + cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) + for stage in range(self.q_stage) + ) + + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + + tOrPs = [ + cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], + tOrP.layout, + ) + for stage in range(2) + ] + + block_info = BlockInfo( + # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) + self.cta_tiler[0], + self.cta_tiler[1], + self.is_causal, + self.is_local, + self.is_split_kv, + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0] + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + if const_expr(len(self.empty_warp_ids) > 0): + if warp_idx == self.empty_warp_ids[0]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + if const_expr(len(self.empty_warp_ids) > 1): + if warp_idx == self.empty_warp_ids[1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + assert len(self.empty_warp_ids) <= 2 + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.load( + thr_mma_qk, + thr_mma_pv, + mQ, + mK, + mV, + sQ, + sK, + sV, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_kv, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + if warp_idx == self.mma_warp_id: + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_qk, + tiled_mma_pv, + sQ, + sK, + sV, + tStSs, + tOtOs, + tOrPs, + pipeline_kv, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + ) + + # if warp_idx == self.mma_warp_id: + # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if const_expr(not self.use_correction_warps_for_epi): + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.epilogue_s2g( + mO, + sO, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx < self.correction_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + softmax_loop = partial( + self.softmax_loop, + softmax_scale_log2=softmax_scale_log2, + softmax_scale=softmax_scale, + thr_mma_qk=thr_mma_qk, + sScale=sScale, + mLSE=mLSE, + learnable_sink=learnable_sink, + mbar_ptr=mbar_ptr, + block_info=block_info, + num_splits=num_splits, + SeqlenInfoCls=SeqlenInfoCls, + AttentionMaskCls=AttentionMaskCls, + TileSchedulerCls=TileSchedulerCls, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + blocksparse_tensors=blocksparse_tensors, + ) + + if const_expr(not self.s0_s1_barrier): + stage = Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) + softmax_loop( + stage=stage, + tStSi=cute.make_tensor( + tStS.iterator + + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), + tStS.layout, + ), + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + else: + # If there's s0_s1_barrier, it's faster to have 2 WGs having different code + if warp_idx < self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[0], tStS.layout) + softmax_loop(stage=0, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[1], tStS.layout) + softmax_loop(stage=1, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + self.correction_loop( + thr_mma_qk, + thr_mma_pv, + tStS, + tOtOs, + sScale, + mO, + mLSE, + sO, + learnable_sink, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + softmax_scale_log2, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + return + + @cute.jit + def load( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + ): + num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE + tidx = cute.arch.thread_idx()[0] % num_load_threads + q_producer_phase = Int32(1) + kv_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.kv_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) + + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + if const_expr(mPageTable is None): + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv]) + mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv]) + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0)) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None)) + else: + # Need to keep batch coord None since we'll index into it with page idx + mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)] + gK = cute.local_tile( + mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None) + ) + gV = cute.local_tile( + mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None) + ) + tSgQ = thr_mma_qk.partition_A(gQ) + tSgK = thr_mma_qk.partition_B(gK) + tOgV = thr_mma_pv.partition_B(gV) + load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ + ) + + if const_expr(self.use_tma_KV): + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + paged_kv_manager = None + else: + page_size = mK.shape[0] + paged_kv_manager = PagedKVManager.create( + mPageTable, + mK, + mV, + FastDivmodDivisor(page_size), + batch_idx, + head_idx_kv, + tidx, + seqlen.seqlen_k, + 0, # leftpad_k + self.n_block_size, + self.head_dim_padded, + self.head_dim_v_padded, + num_load_threads, + mK.element_type, + ) + tKsK, tKgK = None, None + tVsV, tVgV = None, None + + load_Q = partial( + self.load_Q, + load_Q_fn, + mbar_ptr + self.mbar_load_q_full_offset, + mbar_ptr + self.mbar_load_q_empty_offset, + phase=q_producer_phase, + ) + # We have to use mbarrier directly in the load for KV instead of replying on + # pipeline_kv, because we could have different number of TMA bytes for K and V + load_K = partial( + self.load_KV, + tma_atom_K, + tKgK, + tKsK, + paged_kv_manager, + sK, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="K", + ) + load_V = partial( + self.load_KV, + tma_atom_V, + tVgV, + tVsV, + paged_kv_manager, + sV, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="V", + ) + + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx, num_splits + ) + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + n_block_first = n_block_max - 1 if n_block_max > 0 else 0 + page_idx = ( + mPageTable[batch_idx, n_block_first] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block_first) + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 + kv_producer_state.advance() + if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + q_producer_phase ^= 1 + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() + + else: + kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, + self.q_stage, + q_producer_phase, + ) + + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.core.ThrMma, + tiled_mma_pv: cute.core.ThrMma, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tStSs: Tuple[cute.Tensor, cute.Tensor], + tOtOs: tuple[cute.Tensor], + tOrPs: Tuple[cute.Tensor, cute.Tensor], + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + ): + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tOrV = tiled_mma_pv.make_fragment_B(sV) + if const_expr(self.q_stage == 2): + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) + else: + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 0]) + + qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op + + gemm_Si = [ + partial( + sm100_utils.gemm_ptx_partial, + qk_mma_op, + self.tmem_s_offset[stage], + tSrQs[stage], + sA=sQ[None, None, None, stage], + zero_init=True, + ) + for stage in range(2) + ] + gemm_Pi = [ + partial( + sm100_utils.gemm_ptx_partial, + pv_mma_op, + self.tmem_o_offset[stage if self.q_stage == 2 else 0], + tOrPs[stage], + sA=None, + ) + for stage in range(2) + ] + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage + ) + P_full_O_rescaled_phase = Int32(0) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + block_iter_count = Int32(0) + process_tile = False + + if const_expr(self.use_block_sparsity): + block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + process_tile = block_iter_count > Int32(0) + else: + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + block_iter_count = n_block_max - n_block_min + if const_expr(not self.is_split_kv): + process_tile = True + else: + process_tile = n_block_min < n_block_max + + if process_tile: + for stage in cutlass.range_constexpr(self.q_stage): + # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) + # 1. wait for Q0 / Q1 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase + ) + # 2. wait for K0 + if const_expr(stage == 0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + # We don't need to acquire empty S0 / S1. + # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 + # are empty. For subsequent iterations, the wait happened at the end + # of the while loop. + # 3. gemm + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) + sK_cur = sK[None, None, None, mma_kv_consumer_state.index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem( + sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase + ) + gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) + # 4. release S0 / S1 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + mma_q_consumer_phase ^= 1 + # 5. release K0 + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + block_loop_count = block_iter_count - 1 + O_should_accumulate = False + for i in cutlass.range(block_loop_count, unroll=1): + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + mma_kv_release_state = mma_kv_consumer_state.clone() + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(2): + # 2. acquire corrected O0/O1_partial and P0 / P1 + # For the first iteration in this work tile, waiting for O0/O1_partial + # means that the correction warps has finished reading tO during + # the last iteration of the previous work tile has finished. + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, + P_full_O_rescaled_phase, + ) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + # 4. release accumulated O0_partial / O1_partial + # Don't need to signal O_full to the correction warps anymore since the + # correction warps wait for the softmax warps anyway. By the time the softmax + # warps finished, S_i for the next iteration must have been done, so O_i-1 + # must have been done as well. + # with cute.arch.elect_one(): + # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # 5. release V(i-1) + if const_expr(stage == 1): + pipeline_kv.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + if const_expr(stage == 0): + mma_kv_consumer_state.advance() + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + # 2. gemm + # Don't need to wait for the softmax warp to have finished reading the previous + # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si + # has been read and Pi has been written. + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True) + sK_cur = sK[None, None, None, Ki_index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) + gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) + # 3. release S0 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + # End of GEMM_QK0i (Q0 * Ki -> S0) + # 4. release Ki + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + P_full_O_rescaled_phase ^= 1 + O_should_accumulate = True + # End of seqlen_kv loop + + # release Q0 & Q1 + with cute.arch.elect_one(): + for stage in cutlass.range_constexpr(self.q_stage): + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) + + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(2): + # 2. acquire corrected Oi_partial and Pi + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase + ) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + # 4. release accumulated O0_partial + # We do need O_full here since for the last tile, by the time the softmax warp + # has signaled to the correction warps, the softmax warp has just finished compute + # the row sum of the current tile. It does not guarantee that the 1st tile + # of the next work tile has been computed yet. + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + P_full_O_rescaled_phase ^= 1 + # 5. release Vi_end + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + + # for both softmax0 and softmax1 warp group + @cute.jit + def softmax_loop( + self, + stage: int | Int32, + softmax_scale_log2: Float32, + softmax_scale: Float32, + thr_mma_qk: cute.core.ThrMma, + tStSi: cute.Tensor, + sScale: cute.Tensor, + mLSE: Optional[cute.Tensor], + learnable_sink: Optional[cute.Tensor], + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + This method handles the softmax computation for either the first or second half of the + attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum + and sum values needed for stable softmax computation, applies optional masking, and + transforms raw attention scores into probability distributions. + + The implementation uses specialized memory access patterns and efficient math operations + for computing exp(x) using exp2 functions. It also coordinates pipeline + synchronization between MMA, correction, and sequence processing stages. + """ + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE + # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) + * (len(self.softmax0_warp_ids)) + ) + + tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1))) + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + + tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width + tStP_layout = cute.composition( + tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) + ) + tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + Float32, + ) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) + tStS_t2r = thr_tmem_load.partition_S(tStSi) + + tmem_store_scale_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), + Float32, + ) + thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice( + tidx + ) + + tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), + Float32, + ) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + mma_si_consumer_phase = Int32(0) + si_corr_producer_phase = Int32(1) + s0_s1_sequence_phase = Int32(1 if stage == 0 else 0) + + # self.warp_scheduler_barrier_init() + + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + shared_mask_kwargs = dict( + m_block=self.q_stage * m_block + stage, + thr_mma=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + mask_causal=self.is_causal, + mask_local=self.is_local, + batch_idx=batch_idx, + head_idx=head_idx, + aux_tensors=aux_tensors, + ) + mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None + mask_fn = partial( + mask.apply_mask_sm100, + mask_mod=mask_mod, + fastdiv_mods=fastdiv_mods, + **shared_mask_kwargs, + ) + if const_expr(self.use_block_sparsity): + # Full blocks dont need mask_mod + mask_fn_none = partial( + mask.apply_mask_sm100, + mask_mod=None, + fastdiv_mods=fastdiv_mods, + **shared_mask_kwargs, + ) + else: + mask_fn_none = None + + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + if const_expr(self.use_block_sparsity): + tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + has_work = tile_block_count > Int32(0) + else: + tile_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0) + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=self.q_stage * m_block + stage, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + if has_work: + # Softmax acts as the producer: wait until correction signals the stage is empty + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) + si_corr_producer_phase ^= 1 + + # Block sparse or dense iteration + if const_expr(self.use_block_sparsity): + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + empty_tile, + ) = softmax_block_sparse_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + softmax_step, + mask_fn, + mask_fn_none, + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mbar_ptr, + self.mbar_softmax_corr_full_offset, + self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.q_stage, + Int32(stage), + ) + if not empty_tile: + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + else: + if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking (but may still need mask_mod) + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - n_tile - 1 + if const_expr(self.mask_mod is not None): + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + else: + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + ) + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + + # Dense path always writes scale / signals + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + + # # Write LSE to gmem + # if const_expr(mLSE is not None): + # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0] + # scale = ( + # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0) + # ) + # LN2 = math.log(2.0) + # lse = ( + # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2 + # if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + # ) + # if const_expr(not seqlen.has_cu_seqlens_q): + # mLSE_cur = mLSE[None, head_idx, batch_idx] + # else: + # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + # gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,)) + # if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: + # gLSE[tidx] = lse + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def softmax_step( + self, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + n_block: Int32, + softmax: SoftmaxSm100, + mbar_ptr: cute.Pointer, + mbar_s0_s1_sequence_offset: Int32, + thr_mma_qk: cute.core.ThrMma, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_scale: cute.CopyAtom, + tStS_t2r: cute.Tensor, + tStScale_r2t: cute.Tensor, + tStP_r2t: cute.Tensor, + sScale: cute.Tensor, + stage: int | Int32, + batch_idx: Int32, + head_idx: Int32, + m_block: Int32, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + mask_fn: Optional[Callable] = None, + is_first: bool = False, + ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + """ + tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + tScP = cute.composition(tScS, cute.make_layout((self.m_block_size, tilePlikeFP32))) + + # Wait for Si + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) + tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if cutlass.const_expr(self.score_mod is not None): + self.apply_score_mod( + tSrS_t2r, + thr_tmem_load, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax, + aux_tensors, + fastdiv_mods, + ) + + if const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block) + row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) + + if const_expr(not is_first): + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScScale).shape, Float32) + # tSrScale_r2t[0] = acc_scale + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + thread_idx = thr_tmem_load.thr_idx + sScale[thread_idx + stage * self.m_block_size] = acc_scale + # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) + # Notify correction wg that row_max is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + + # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) + # print(tSrS_t2r) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + # Sequence barrier wait + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_wait( + mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase + ) + tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, + ) + # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + e2e=mask_fn is None and self.head_dim_padded <= 128, + e2e_freq=self.e2e_freq, + ) + # Sequence barrier arrive + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) + # print(tSrP_r2t_f32, tStP_r2t) + # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + for i in cutlass.range_constexpr( + cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2]) + ): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that the 2nd half of P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + # acc_scale = cute.arch.exp2(acc_scale_) + return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 + + @cute.jit + def correction_loop( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOtOs: tuple[cute.Tensor], + sScale: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: cute.CopyAtom, + mbar_ptr: cute.Pointer, + softmax_scale_log2: Float32, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) + tStScales = tuple( + cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) + for stage in range(2) + ) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), + self.qk_acc_dtype, + ) + thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx) + + tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(2)] + tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape + + # First iter: no correction is required + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 0) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 1) + + softmax_corr_consumer_phase = Int32(0) + o_corr_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] + else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + + # Default LSE to -inf for invalid split_idx tiles + stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage + + if const_expr(self.use_block_sparsity): + total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + has_work = total_block_count > Int32(0) + else: + total_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0) + + if has_work: + # Ignore first signal from softmax as no correction is required + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase + ) + softmax_corr_consumer_phase ^= 1 + + tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) + for i in cutlass.range(total_block_count - 1, unroll=1): + for stage in cutlass.range_constexpr(2): + # wait for S0 / S1 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + scale = sScale[tidx + stage * self.m_block_size] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + # should_rescale = True + # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) + # Don't need O_full anymore, since by the time softmax has signaled the correction + # warps, S_i must have been done, so O_i-1 must have been done as well. + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + if should_rescale: + self.correction_rescale( + thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) + ) + softmax_corr_consumer_phase ^= 1 + # o_corr_consumer_phase ^= 1 + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + # End of seqlen_corr_loop_steps + + # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without + # additional sync because the MMA in the top half must have been done. + # Similarly we can write to stage 1 of sO without additional sync. + learnable_sink_val = [None] * self.q_stage + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + learnable_sink_val = [sink_val] * self.q_stage + else: # Each thread might have a different sink value due to different q_head + for stage in cutlass.range_constexpr(self.q_stage): + q_head_idx = ( + (self.q_stage * m_block + stage) * self.m_block_size + tidx + ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + row_sum = sScale[tidx + stage * self.m_block_size] + if const_expr(mLSE is not None or learnable_sink is not None): + row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] + else: + row_max = None + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + if const_expr(learnable_sink is not None): + LOG2_E = math.log2(math.e) + sink_val = learnable_sink_val[stage] + if const_expr(not self.is_split_kv) or split_idx == 0: + if row_max == -Float32.inf: + # It's possible to have an empty row with splitKV. + row_max = sink_val * (LOG2_E / softmax_scale_log2) + row_sum = Float32(1.0) + else: + row_sum += utils.exp2f( + sink_val * LOG2_E - row_max * softmax_scale_log2 + ) + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase + ) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase + ) + self.correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + stage, + m_block, + seqlen.seqlen_q, + scale, + sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, + ) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) + # Signal for the next work tile that O buffers in tmem are already read, so + # mma warp can write to them + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + + o_corr_consumer_phase ^= 1 + softmax_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + else: + # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781 + if const_expr(self.use_correction_warps_for_epi): + gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O + else: + gmem_tiled_copy_O_for_empty_tile = None + if const_expr(self.use_block_sparsity): + ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) = handle_block_sparse_empty_tile_correction_sm100( + tidx, + self.q_stage, + self.m_block_size, + self.qhead_per_kvhead, + self.pack_gqa, + self.is_split_kv, + learnable_sink, + mLSE, + seqlen, + m_block, + head_idx, + batch_idx, + split_idx, + sScale, + stats, + self.correction_epilogue, + thr_mma_pv, + tOtOs, + sO, + mbar_ptr, + self.mbar_softmax_corr_full_offset, + self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.mbar_corr_epi_full_offset, + self.mbar_corr_epi_empty_offset, + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + softmax_scale_log2, + mO_cur, + gO, + gmem_tiled_copy_O_for_empty_tile, + ) + + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(self.is_split_kv): + mLSE_cur = mLSE[None, head_idx, batch_idx, split_idx] + else: + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + offset = ( + seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + ) + if const_expr(self.is_split_kv): + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx, split_idx]) + else: + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + gLSE = cute.local_tile( + mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,) + ) + row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] + # if tidx == 0 and stage <= 1: + # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead + ) + if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: + # This actually just works with PackGQA too + gLSE[tidx] = lse + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def correction_rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + tidx: Int32, + scale: Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + This method performs a crucial correction step in the attention computation pipeline. + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + + The implementation uses efficient tensor memory operations to: + 1. Load existing partial attention output from tensor memory + 2. Apply the scaling factor to all elements + 3. Store the rescaled results back to tensor memory + """ + tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tOtO_i = cute.composition(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.composition(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i).get_slice(tidx) + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count = self.head_dim_v_padded // corr_tile_size + tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype) + for i in cutlass.range_constexpr(frg_count): + tOrO_frg = cute.make_fragment(tOrO_t2r_shape, self.pv_acc_dtype) + tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout) + cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): + tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) + cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def correction_epilogue( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + tidx: Int32, + stage: Int32, + m_block: Int32, + seqlen_q: Int32, + scale: Float32, + sO: cute.Tensor, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + This correction_epilogue function handles the final processing step for attention output values. + It applies a scaling factor to the accumulated attention results and prepares the + data for efficient transfer back to global memory. + + The method performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion if necessary (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. Preparation for efficient TMA store operations + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor containing accumulated attention output + :type tOtO: cute.Tensor + :param scale: Final scaling factor to apply to the output + :type scale: Float32 + :param sO: Shared memory tensor for the final output + :type sO: cute.Tensor + """ + + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size))) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils_basic.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice( + tidx + ) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size): + tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] + tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): + tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype) + tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype)) + cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i) + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + if const_expr(self.use_correction_warps_for_epi): + assert(not self.use_tma_O) + assert(gmem_tiled_copy_O is not None) + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO, self.o_dtype) + cute.autovec_copy(tOsO, tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen_q, + ) + + @cute.jit + def epilogue_s2g( + self, + mO: cute.Tensor, + sO: cute.Tensor, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: int, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + epi_consumer_phase = Int32(0) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] + else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + if const_expr(self.use_tma_O): + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO + ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) + # 2. copy O0 / O1 to gmem + store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage) + cute.arch.cp_async_bulk_commit_group() + for stage in cutlass.range_constexpr(self.q_stage): + # Ensure O0 / O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + else: + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + ) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) + # 2. copy O0 / O1 to gmem + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) + cute.autovec_copy(tOsO[None, None, None, stage], tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen.seqlen_q, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + + epi_consumer_phase ^= 1 + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + def load_Q( + self, + load_Q_fn: Callable, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + stage: int, + phase: Int32, + ): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes["Q"]) + load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=mbar_full_ptr + stage) + + @cute.jit + def load_KV( + self, + tma_atom: Optional[cute.CopyAtom], + tXgX: Optional[cute.Tensor], + tXsX: Optional[cute.Tensor], + paged_kv_manager: Optional[PagedKVManager], + sX: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + producer_state: cutlass.pipeline.PipelineState, + K_or_V: Literal["K", "V"], + page_idx: Optional[Int32] = None, + ): + assert K_or_V in ("K", "V") + stage, phase = producer_state.index, producer_state.phase + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + if const_expr(K_or_V == "K" and self.uneven_kv_smem): + # Before this round, the smem location was occupied by V, which is smaller than + # K. So we need to wait for the stage after that (stage 1) to be empty as well. + if stage == 0: + cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) + + if const_expr(self.use_tma_KV): + assert ( + tXgX is not None and + tXsX is not None and + tma_atom is not None + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V], + ) + tXsX_cur = tXsX[None, stage] + if const_expr(self.uneven_kv_smem): + # Since this is the producer_state, the phase starts at 1, so we have to invert it + tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) + # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 + tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx] + cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + else: + assert paged_kv_manager is not None + paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_mbarrier_arrive_noinc(mbar_full_ptr + stage) + + @cute.jit + def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): + if const_expr(self.uneven_kv_smem): + # smem layout is [smem_large, smem_small, smem_large], and the current stride is + # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if + # phase == 0, or left by offset if phase == 1. + offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) + return cute.make_tensor(sX.iterator + offset, sX.layout) + else: + return sX + + def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): + load_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + if self.use_tma_KV: + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) + ) + return cutlass.pipeline.PipelineTmaUmma.create( + barrier_storage=load_kv_mbar_ptr, + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + tx_count=self.tma_copy_bytes["K"], + ) + else: + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE + ) + return cutlass.pipeline.PipelineAsyncUmma.create( + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + barrier_storage=load_kv_mbar_ptr, + ) + + # @cute.jit + # def warp_scheduler_barrier_init(self): + # warp_group_idx = utils.canonical_warp_group_idx(sync=False) + # if warp_group_idx == 0: + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # ) + + # def warp_scheduler_barrier_sync(self): + # cute.arch.barrier( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # number_of_threads=2 * 128 + # ) + + # def warp_scheduler_barrier_arrive(self): + # cur_wg = utils.canonical_warp_group_idx(sync=False) + # next_wg = 1 - cur_wg + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # ) + + @cute.jit + def apply_score_mod( + self, + tSrS_t2r, + thr_tmem_load, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + """Apply score modification for SM100 (constant q_idx).""" + # Prepare index tensor with extra partition + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS) + tScS = thr_mma_qk.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + + # Shared q_idx for all scores + q_idx_logical = tScS_t2r[0][0] + + # For Pack-GQA, compute the logical head index for this tile + if cutlass.const_expr(self.pack_gqa): + # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead) + q_physical = q_idx_logical + q_idx_logical = q_physical // self.qhead_per_kvhead + head_offset = q_physical - q_idx_logical * self.qhead_per_kvhead + head_idx = head_idx * self.qhead_per_kvhead + head_offset + + if cutlass.const_expr(aux_tensors is not None): + seqlen_q_divmod, _ = fastdiv_mods + _, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod) + + apply_score_mod_inner( + tSrS_t2r, + tScS_t2r, + self.score_mod, + batch_idx, + head_idx, + softmax.softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + constant_q_idx=q_idx_logical, + qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, + ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index d42c33e76e7..c6a1c301904 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -1,9 +1,15 @@ # Copyright (c) 2025, Tri Dao. +from typing import Type, Union, Optional import cutlass import cutlass.cute as cute +from cutlass import Int32, Float32, Boolean, const_expr from cutlass.cute.nvgpu import warpgroup +from cutlass.cutlass_dsl import Numeric, dsl_user_op +from cutlass.utils import LayoutEnum +import cutlass.utils.hopper_helpers as sm90_utils_og +@cute.jit def gemm( tiled_mma: cute.TiledMma, acc: cute.Tensor, @@ -14,14 +20,82 @@ def gemm( # A_in_regs: cutlass.Constexpr[bool] = False, swap_AB: cutlass.Constexpr[bool] = False, ) -> None: - if swap_AB: + if const_expr(swap_AB): gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) else: warpgroup.fence() - tiled_mma.set(warpgroup.Field.ACCUMULATE, not zero_init) + # We make a new mma_atom since we'll be modifying its attribute (accumulate). + # Otherwise the compiler complains "operand #0 does not dominate this use" + mma_atom = cute.make_mma_atom(tiled_mma.op) + mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init) for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): - cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) - tiled_mma.set(warpgroup.Field.ACCUMULATE, True) + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + mma_atom.set(warpgroup.Field.ACCUMULATE, True) warpgroup.commit_group() - if cutlass.const_expr(wg_wait >= 0): + if const_expr(wg_wait >= 0): warpgroup.wait_group(wg_wait) + + +def gemm_zero_init( + tiled_mma: cute.TiledMma, + shape: cute.Shape, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, + swap_AB: bool = False, +) -> cute.Tensor: + if const_expr(swap_AB): + return gemm_zero_init( + tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False + ) + else: + acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) + return acc + + +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: Boolean, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, + swap_AB: bool = False, +) -> None: + if const_expr(swap_AB): + gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) + + +@dsl_user_op +def make_smem_layout( + dtype: Type[Numeric], + layout: LayoutEnum, + shape: cute.Shape, + stage: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] + smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size), + dtype, + ) + order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2) + smem_layout_staged = cute.tile_to_shape( + smem_layout_atom, + cute.append(shape, stage) if const_expr(stage is not None) else shape, + order=order if const_expr(stage is not None) else order[:2], + ) + return smem_layout_staged diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 9a5bd894b56..651e9393135 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,18 +1,28 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -# [2025-06-01] Initial version in Cute-DSL. -# Only support basic forward and backward pass for FlashAttention, optimized for Ampere. -# Lightly tested with headdim 128. -# Features not supported yet: +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. + +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128. +# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) # - varlen # - sliding window +# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) + +# Features not supported yet: # - split (i.e. FlashDecoding) # - tuned block sizes # - paged KV # - append KV to existing KV cache # - FP8 +# - bwd pass optimized for Hopper/Blackwell import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Callable import torch @@ -24,10 +34,19 @@ from flash_attn.cute import utils from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80, FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_attn.cute.block_sparsity import ( + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, + normalize_block_sparse_tensors, +) def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -40,6 +59,16 @@ def maybe_contiguous(x): } +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): + # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. + if num_n_blocks <= 4: + return 1 + + # NOTE: We should revisit this heuristic after persistence is supported for split KV. + # Sometimes, it's ideal to over-schedule splits for better efficiency. + return min(num_SMs // total_mblocks, max_splits, num_n_blocks) + + def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -48,17 +77,42 @@ def _flash_attn_fwd( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, - softcap: float = 0.0, + softcap: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + learnable_sink: Optional[torch.Tensor] = None, # m_block_size: int = 128, # n_block_size: int = 64, # num_threads: int = 128, m_block_size: int = 128, n_block_size: int = 128, num_threads: int = 384, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + _compute_capability: Optional[int] = None, + score_mod: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + return_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + aux_tensors: Optional[list[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for FlashAttention. + + Args: + ... + score_mod: A callable that takes the attention scores and applies a modification. + mask_mod: A callable that takes token position information and selectively masks + block_sparse_tensors: A tuple of tensors used for block sparsity. + return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate + out: Optional pre-allocated output tensor. If None, will be allocated internally. + lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. + aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. + """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: @@ -66,91 +120,411 @@ def _flash_attn_fwd( total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 - seqlen_q = max_seqlen_q + seqlen_q = None total_q = q.shape[0] - seqlen_k, num_head_kv, _ = k.shape[-3:] + if page_table is not None: + assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" + assert page_table.dtype == torch.int32, "page_table must be int32" + assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" + max_num_pages_per_seq = page_table.shape[1] + assert page_table.shape == (batch_size, max_num_pages_per_seq) + num_pages, page_size = k.shape[:2] + seqlen_k = num_pages * page_size + else: + num_pages, page_size = None, None + seqlen_k = k.shape[-3] + num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] if cu_seqlens_k is None: - assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) - assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + if page_table is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (num_pages, page_size, num_head_kv, head_dim) + assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) else: assert k.shape == (seqlen_k, num_head_kv, head_dim) assert v.shape == (seqlen_k, num_head_kv, head_dim_v) - assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) + if cu_seqlens_q is not None: - assert max_seqlen_q is not None, "max_seqlen_q must be provided if cu_seqlens_q is provided" - assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" - assert seqused_q is None or seqused_q.shape == (batch_size,), "seqused_q must have shape (batch_size,)" - assert seqused_k is None or seqused_k.shape == (batch_size,), "seqused_k must have shape (batch_size,)" + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) + assert seqused_q is None or seqused_q.shape == (batch_size,), ( + "seqused_q must have shape (batch_size,)" + ) + assert seqused_k is None or seqused_k.shape == (batch_size,), ( + "seqused_k must have shape (batch_size,)" + ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: - assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" - assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" - assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)), "inputs must be on CUDA device" + assert t.dtype == torch.int32, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" + ) + assert t.stride(0) == 1, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + ) + if learnable_sink is not None: + assert learnable_sink.shape == (num_head,) + assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" + + assert all( + t is None or t.is_cuda + for t in ( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + learnable_sink, + ) + ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" - alignment = 128 // q.element_size() + alignment = 16 // q.element_size() assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) + if softcap == 0.0: + softcap = None qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 out_torch_dtype = q.dtype device = q.device q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) - out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) - lse = torch.empty(lse_shape, dtype=torch.float32, device=device) + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad + + if out is None: + out = torch.empty( + *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device + ) + else: + expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v) + assert out.shape == expected_out_shape, ( + f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" + ) + assert out.dtype == out_torch_dtype, ( + f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}" + ) + assert out.device == device, ( + f"out tensor device {out.device} does not match input device {device}" + ) + assert out.is_cuda, "out tensor must be on CUDA device" + + if lse is None: + lse = ( + torch.empty(lse_shape, dtype=torch.float32, device=device) + if requires_grad or return_lse + else None + ) + elif lse is not None: + assert lse.shape == lse_shape, ( + f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" + ) + assert lse.dtype == torch.float32, ( + f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" + ) + assert lse.device == device, ( + f"lse tensor device {lse.device} does not match input device {device}" + ) + assert lse.is_cuda, "lse tensor must be on CUDA device" dtype = torch2cute_dtype_map[q.dtype] - q_tensor, k_tensor, v_tensor, o_tensor = [ - utils.convert_from_dlpack( - t.detach(), leading_dim=t.ndim - 1, divisibility=128 // dtype.width - ) for t in (q, k, v, out) - ] - lse_tensor = utils.convert_from_dlpack(lse, leading_dim=lse.ndim - 1, alignment=4) - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ( + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + learnable_sink_tensor, + ) = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if t is not None + else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] - max_seqlen_q = cutlass.Int32(max_seqlen_q) if max_seqlen_q is not None else None + page_table_tensor = ( + from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) + if page_table is not None + else None + ) + compute_capability = ( + torch.cuda.get_device_capability()[0] + if _compute_capability is None + else _compute_capability + ) + + assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" + + + sparse_tensors = None + if block_sparse_tensors is not None: + if seqlen_q is None: + raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") + m_block_size_block = m_block_size + if compute_capability == 10: + # TODO: This multiplier should really be q_stage, wire up in later PR + # 1 cta handles 2*tile_m row + m_block_size_block = 2 * m_block_size + expected_m_blocks = (seqlen_q + m_block_size_block - 1) // m_block_size_block + expected_n_blocks = (seqlen_k + n_block_size - 1) // n_block_size + block_sparse_tensors = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=(batch_size, num_head, expected_m_blocks), + expected_index_shape=(batch_size, num_head, expected_m_blocks, expected_n_blocks), + ) + sparse_tensors = to_cute_block_sparse_tensors(block_sparse_tensors) + + use_block_sparsity = sparse_tensors is not None + + if mask_mod is None: + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + else: + causal, local = False, False + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + if compute_capability == 9: # TODO: tune block size according to hdim. + if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: + n_block_size = 192 + if compute_capability == 10: + # TODO: fix the varlen case + if ( + pack_gqa + and (128 % qhead_per_kvhead != 0) + or (cu_seqlens_q is not None or seqused_q is not None) + ): + pack_gqa = False + # TODO: fix GQA + SplitKV + non-varlen + if pack_gqa and num_splits != 1 and cu_seqlens_q is None: + pack_gqa = False + + if num_splits < 1: + max_seqlen_k = seqlen_k if cu_seqlens_k is None else (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + max_seqlen_q = seqlen_q if cu_seqlens_q is None else (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead + seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size)) + num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size + num_m_blocks = (seqlen_q_packgqa + m_block_size - 1) // m_block_size + total_mblocks = batch_size * num_head_kv * num_m_blocks + num_splits = num_splits_heuristic( + total_mblocks, + torch.cuda.get_device_properties(device).multi_processor_count, + num_n_blocks, + 128, + ) + + is_split_kv = num_splits > 1 + if is_split_kv: + out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) + lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) + + q_tensor, k_tensor, v_tensor, o_tensor = [ + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out if not is_split_kv else out_partial) + ] + if is_split_kv: + lse_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 1) + elif lse is not None: + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + else: + lse_tensor = None + + # hash score and mask mods for compile cache + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False + + if softcap is not None: + assert score_mod is None, "softcap and score_mod cannot be used together" + score_mod = utils.create_softcap_scoremod(softcap) + + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) + if score_mod is not None: + if is_varlen: + raise NotImplementedError( + "score_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." + ) + + if mask_mod is not None: + if is_varlen: + raise NotImplementedError( + "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." + ) + if pack_gqa: + raise NotImplementedError( + "mask_mod with aux_tensors is not yet supported with pack_gqa=True. This will be fixed in a future PR." + ) + + if use_block_sparsity: + if is_varlen: + raise NotImplementedError( + "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." + ) + if pack_gqa: + raise NotImplementedError( + "Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR." + ) + if is_split_kv: + raise NotImplementedError( + "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." + ) + + cute_aux_tensors = None + if aux_tensors is not None: + cute_aux_tensors = [from_dlpack(buf).mark_layout_dynamic() for buf in aux_tensors] + compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, - cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, - m_block_size, n_block_size, num_threads + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + score_mod_hash, + mask_mod_hash, + use_block_sparsity, + len(aux_tensors) if aux_tensors is not None else 0, + lse is None, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, + page_table is not None, + window_size_left is not None, + window_size_right is not None, + learnable_sink is not None, + m_block_size, + n_block_size, + num_threads, + is_split_kv, + pack_gqa, + compute_capability, + page_size not in [None, 128], # paged KV non-TMA ) if compile_key not in _flash_attn_fwd.compile_cache: - # fa_fwd = FlashAttentionForwardSm80( - fa_fwd = FlashAttentionForwardSm90( - dtype, - head_dim, - head_dim_v, - qhead_per_kvhead, - is_causal=causal, - has_softcap=softcap != 0.0, - m_block_size=m_block_size, - n_block_size=n_block_size, - # num_stages=1, - num_stages=2, - num_threads=num_threads, - Q_in_regs=False, - ) + if compute_capability == 9: + assert page_table is None, "paged KV not supported on SM 9.0" + assert not is_split_kv, "SplitKV not supported on SM 9.0" + # fa_fwd = FlashAttentionForwardSm80( + fa_fwd = FlashAttentionForwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + tile_m=m_block_size, + tile_n=n_block_size, + # num_stages=1, + num_stages=2, + num_threads=num_threads, + Q_in_regs=False, + intra_wg_overlap=True, + mma_pv_is_rs=True, + mask_mod=mask_mod, + score_mod=score_mod, + has_aux_tensors=aux_tensors is not None, + ) + elif compute_capability == 10: + fa_fwd = FlashAttentionForwardSm100( + head_dim, + head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, + is_causal=causal, + is_local=local, + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + m_block_size=m_block_size, + n_block_size=n_block_size, + is_persistent=not causal + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + paged_kv_non_tma=page_size not in [None, 128], + is_varlen_q=cu_seqlens_q is not None + or seqused_q is not None, + ) + else: + raise ValueError( + f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x" + ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - max_seqlen_q, softmax_scale, softcap, current_stream + fa_fwd, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + page_table_tensor, + window_size_left, + window_size_right, + learnable_sink_tensor, + sparse_tensors, + cute_aux_tensors, ) _flash_attn_fwd.compile_cache[compile_key]( - q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - max_seqlen_q, softmax_scale, softcap, current_stream + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + page_table_tensor, + window_size_left, + window_size_right, + learnable_sink_tensor, + sparse_tensors, + cute_aux_tensors, ) + if is_split_kv: + _flash_attn_fwd_combine( + out_partial, + lse_partial.transpose(-1, -2), + out, + lse.transpose(-1, -2) if lse is not None else None, + cu_seqlens_q, + seqused_q, + ) return out, lse @@ -167,9 +541,12 @@ def _flash_attn_bwd( softmax_scale: Optional[float] = None, causal: bool = False, softcap: float = 0.0, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, m_block_size: int = 64, n_block_size: int = 128, num_threads: int = 256, + pack_gqa: bool = False, num_stages_Q: int = 2, num_stages_dO: int = 2, SdP_swapAB: bool = False, @@ -179,84 +556,315 @@ def _flash_attn_bwd( AtomLayoutNdKV: int = 2, AtomLayoutMdQ: int = 2, V_in_regs: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + deterministic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, out, dout, lse = [maybe_contiguous(t) for t in (q, k, v, out, dout, lse)] - batch_size, seqlen_q, num_head, head_dim = q.shape - _, seqlen_k, num_head_kv, _ = k.shape - _, _, _, head_dim_v = v.shape - assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) - assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) - assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) - assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) - assert lse.shape == (batch_size, num_head, seqlen_q), "lse must have shape (batch_size, num_head, seqlen_q)" + compute_capability = torch.cuda.get_device_capability()[0] + assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" + + if compute_capability == 9: + m_block_size = 80 if not causal else 64 + n_block_size = 128 + num_stages_Q = 2 + num_stages_dO = 2 + num_stages_PdS = 2 + SdP_swapAB = True + dKV_swapAB = False + dQ_swapAB = not causal + AtomLayoutMSdP = 1 + AtomLayoutNdKV = 2 + AtomLayoutMdQ = 1 + cluster_size = 1 + assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" + else: + m_block_size = 128 + n_block_size = 128 + dQ_swapAB = False + dKV_swapAB = False + AtomLayoutMdQ = 1 + AtomLayoutNdKV = 1 + # TODO: support cluster size 2 + cluster_size = 1 + q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ + maybe_contiguous(t) + for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ] + num_head, head_dim = q.shape[-2:] + if cu_seqlens_q is None: + batch_size, seqlen_q = q.shape[:2] + total_q = batch_size * seqlen_q + else: + batch_size = cu_seqlens_q.shape[0] - 1 + seqlen_q = None + total_q = q.shape[0] + + if cu_seqlens_k is None: + batch_size, seqlen_k = k.shape[:2] + total_k = batch_size * seqlen_k + else: + batch_size = cu_seqlens_k.shape[0] - 1 + seqlen_k = None + total_k = k.shape[0] + + num_head_kv = k.shape[-2] + head_dim_v = v.shape[-1] + + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if local: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + + if cu_seqlens_k is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (total_k, num_head_kv, head_dim) + assert v.shape == (total_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) + + if cu_seqlens_q is not None: + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) + + assert out.shape == (total_q, num_head, head_dim_v) + assert dout.shape == (total_q, num_head, head_dim_v) + assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)" + else: + assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert lse.shape == (batch_size, num_head, seqlen_q), ( + "lse must have shape (batch_size, num_head, seqlen_q)" + ) + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" - assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, "inputs must have the same dtype" + assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, ( + "inputs must have the same dtype" + ) + for t in [cu_seqlens_q, cu_seqlens_k]: + if t is not None: + assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" assert lse.dtype == torch.float32, "lse must be float32" - assert all(t.is_cuda for t in (q, k, v, out, dout, lse)), "inputs must be on CUDA device" + assert all( + t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) + ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" - alignment = 128 // q.element_size() + alignment = 16 // q.element_size() assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 + if compute_capability == 10: + pack_gqa = False # override for now + if compute_capability != 10: + assert deterministic is False, "bwd deterministic only supported for sm100 for now" device = q.device # TODO: check if this is the right rounding - seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size - head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) - dq_accum = torch.empty(batch_size, num_head, seqlen_q_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dpsum = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) - lse_log2 = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + + head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 + + if cu_seqlens_q is None: + seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size + dq_accum = torch.empty( + batch_size, + num_head, + seqlen_q_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dpsum = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) + lse_log2 = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) + else: + total_q_rounded_padded = ( + (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size + ) + dq_accum = torch.empty( + num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + ) + dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + if qhead_per_kvhead > 1: - seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 - dk_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dv_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, dtype=torch.float32, device=device) + if cu_seqlens_k is None: + seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size + num_n_blocks = seqlen_k_rounded // n_block_size + if cluster_size == 2 and num_n_blocks % cluster_size != 0: + seqlen_k_rounded = seqlen_k_rounded + n_block_size + dk_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) + else: + total_k_rounded_padded = ( + (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size + ) + num_n_blocks = total_k_rounded_padded // n_block_size + if cluster_size == 2 and num_n_blocks % cluster_size != 0: + total_k_rounded_padded = total_k_rounded_padded + n_block_size + dk_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ - utils.convert_from_dlpack( - t.detach(), leading_dim=3, divisibility=128 // dtype.width - ) for t in (q, k, v, out, dout, dq, dk, dv) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out, dout, dq, dk, dv) ] - lse_tensor = utils.convert_from_dlpack(lse.detach(), leading_dim=2, alignment=4) + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=lse.ndim - 1 + ) dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - utils.convert_from_dlpack(t.detach(), leading_dim=2, divisibility=128 // cutlass.Float32.width) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dq_accum, dpsum, lse_log2) ] if qhead_per_kvhead > 1: dk_accum_tensor, dv_accum_tensor = [ - utils.convert_from_dlpack(t.detach(), leading_dim=2, divisibility=128 // cutlass.Float32.width) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dk_accum, dv_accum) ] + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim - 1) + if t is not None + else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ] + if deterministic: + dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") + else: + dQ_semaphore = None + + if deterministic and qhead_per_kvhead > 1: + dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + else: + dK_semaphore = None + dV_semaphore = None + dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ + utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) + if t is not None else None + for t in (dQ_semaphore, dK_semaphore, dV_semaphore) + ] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. - compile_key_pre = (dtype, head_dim_v, m_block_size, num_threads) + compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads) if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: fa_bwd_pre = FlashAttentionBackwardPreprocess( - dtype, head_dim_v, m_block_size, num_threads=num_threads, + dtype, + head_dim_v, + m_block_size, + num_threads=num_threads, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( - fa_bwd_pre, o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, - dq_accum_tensor, current_stream + fa_bwd_pre, + o_tensor, + do_tensor, + dpsum_tensor, + lse_tensor, + lse_log2_tensor, + dq_accum_tensor, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, current_stream + o_tensor, + do_tensor, + dpsum_tensor, + lse_tensor, + lse_log2_tensor, + dq_accum_tensor, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) # Backward kernel: compute dk, dv, dq_accum. - compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, m_block_size, - n_block_size, num_threads, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, - AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs - ) + if compute_capability == 9: + compile_key = ( + compute_capability, + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + softcap != 0.0, + m_block_size, + n_block_size, + num_threads, + pack_gqa, + num_stages_Q, + num_stages_dO, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + V_in_regs, + ) + else: + compile_key = ( + compute_capability, + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + window_size_left is not None, + window_size_right is not None, + softcap != 0.0, + m_block_size, + n_block_size, + num_threads, + pack_gqa, + cluster_size, + deterministic, + ) + num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: fa_bwd_sm80 = FlashAttentionBackwardSm80( dtype, @@ -268,6 +876,7 @@ def _flash_attn_bwd( num_stages_Q, num_stages_dO, num_threads, + pack_gqa, causal, SdP_swapAB, dKV_swapAB, @@ -277,34 +886,112 @@ def _flash_attn_bwd( AtomLayoutMdQ, V_in_regs=V_in_regs, ) + if compute_capability == 9: + fa_bwd_obj = FlashAttentionBackwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + m_block_size, + n_block_size, + num_stages_Q, + num_stages_dO, + num_stages_PdS, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + num_threads, + V_in_regs=V_in_regs, + ) + else: + fa_bwd_obj = FlashAttentionBackwardSm100( + head_dim, + head_dim_v, + is_causal=causal, + is_local=local, + qhead_per_kvhead=qhead_per_kvhead, + # tile_m=m_block_size, + # tile_n=n_block_size, + cluster_size=cluster_size, + # cluster_size=1, + deterministic=deterministic, + ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( - fa_bwd_sm80, q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + fa_bwd_obj, + q_tensor, + k_tensor, + v_tensor, + do_tensor, + lse_log2_tensor, + dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, - softmax_scale, current_stream + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + window_size_left=window_size_left, + window_size_right=window_size_right, + mdQ_semaphore=dQ_semaphore_tensor, + mdK_semaphore=dK_semaphore_tensor, + mdV_semaphore=dV_semaphore_tensor, ) _flash_attn_bwd.compile_cache[compile_key]( - q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + q_tensor, + k_tensor, + v_tensor, + do_tensor, + lse_log2_tensor, + dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, - softmax_scale, current_stream + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + window_size_left=window_size_left, + window_size_right=window_size_right, + mdQ_semaphore=dQ_semaphore_tensor, + mdK_semaphore=dK_semaphore_tensor, + mdV_semaphore=dV_semaphore_tensor, ) + num_threads = 256 if compute_capability == 9 else 128 # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: + arch = compute_capability * 10 fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB + dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dq_accum_tensor, dq_tensor, softmax_scale, current_stream + fa_bwd_post, + dq_accum_tensor, + dq_tensor, + softmax_scale, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dq_accum_tensor, dq_tensor, softmax_scale, current_stream + dq_accum_tensor, + dq_tensor, + softmax_scale, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) if qhead_per_kvhead > 1: @@ -316,22 +1003,51 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dk_accum_tensor, dk_tensor, softmax_scale, current_stream + fa_bwd_post, + dk_accum_tensor, + dk_tensor, + softmax_scale, + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dk_accum_tensor, dk_tensor, softmax_scale, current_stream + dk_accum_tensor, + dk_tensor, + softmax_scale, + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, + ) + compile_key_post = ( + dtype, + head_dim_v, + n_block_size, + num_threads, + AtomLayoutNdKV, + dKV_swapAB, ) - compile_key_post = (dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), current_stream + fa_bwd_post, + dv_accum_tensor, + dv_tensor, + cutlass.Float32(1.0), + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), current_stream + dv_accum_tensor, + dv_tensor, + cutlass.Float32(1.0), + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) return dq, dk, dv @@ -343,7 +1059,6 @@ def _flash_attn_bwd( class FlashAttnFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -352,20 +1067,48 @@ def forward( v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): + # Only create block sparse tensors if at least one block sparse parameter is provided + block_sparse_tensors = None + if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]): + block_sparse_tensors = BlockSparseTensorsTorch( + full_block_cnt=full_block_cnt, + full_block_idx=full_block_idx, + mask_block_cnt=mask_block_cnt, + mask_block_idx=mask_block_idx, + ) out, lse = _flash_attn_fwd( q, k, v, softmax_scale=softmax_scale, causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + learnable_sink=learnable_sink, softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + mask_mod=mask_mod, + block_sparse_tensors=block_sparse_tensors ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.softcap = softcap + ctx.deterministic = deterministic return out, lse @staticmethod @@ -381,12 +1124,14 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.softcap, + window_size_left=ctx.window_size[0], + window_size_right=ctx.window_size[1], + deterministic=ctx.deterministic, ) - return dq, dk, dv, *((None,) * 3) + return dq, dk, dv, *((None,) * 20) # Extra Nones is fine class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -395,12 +1140,17 @@ def forward( v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor], cu_seqlens_k: Optional[torch.Tensor], - seqused_q: Optional[torch.Tensor], - seqused_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, ): out, lse = _flash_attn_fwd( q, @@ -410,25 +1160,48 @@ def forward( cu_seqlens_k, seqused_q, seqused_k, - max_seqlen_q, + page_table=page_table, softmax_scale=softmax_scale, causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + learnable_sink=learnable_sink, softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) - ctx.max_seqlen_q = max_seqlen_q ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.softcap = softcap + ctx.deterministic = deterministic return out, lse @staticmethod def backward(ctx, dout, *args): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors - raise NotImplementedError( - "Backward pass for FlashAttention with variable length sequences is not implemented yet." + assert seqused_q == seqused_k == None + assert ctx.softcap == 0.0 + dq, dk, dv = _flash_attn_bwd( + q, + k, + v, + out, + dout, + lse, + ctx.softmax_scale, + ctx.causal, + ctx.softcap, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + deterministic=ctx.deterministic, ) + return dq, dk, dv, *((None,) * 20) + def flash_attn_func( q: torch.Tensor, @@ -436,7 +1209,17 @@ def flash_attn_func( v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): return FlashAttnFunc.apply( q, @@ -444,7 +1227,17 @@ def flash_attn_func( v, softmax_scale, causal, + window_size, + learnable_sink, softcap, + num_splits, + pack_gqa, + deterministic, + mask_mod, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, ) @@ -456,10 +1249,15 @@ def flash_attn_varlen_func( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, ): return FlashAttnVarlenFunc.apply( q, @@ -469,8 +1267,288 @@ def flash_attn_varlen_func( cu_seqlens_k, seqused_q, seqused_k, - max_seqlen_q, + page_table, softmax_scale, causal, + window_size, + learnable_sink, softcap, + num_splits, + pack_gqa, + deterministic, + ) + + +def _flash_attn_fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: torch.Tensor, + lse: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + num_splits_dynamic_ptr: Optional[torch.Tensor] = None, + semaphore_to_reset: Optional[torch.Tensor] = None, +) -> None: + """Forward combine kernel for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. + + Args: + out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or + (num_splits, total_q, nheads, headdim) if there's cu_seqlens + lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or + (num_splits, total_q, nheads) if there's cu_seqlens + out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens + lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens. + cu_seqlens: Cumulative sequence lengths for variable length sequences + seqused: Used sequence lengths for each batch + num_splits_dynamic_ptr: Dynamic number of splits per batch + semaphore_to_reset: Semaphore for synchronization + k_block_size: Block size for head dimension + + Returns: + None + """ + # Input validation + assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" + assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" + assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( + "out_partial must be fp16, bf16, or fp32" + ) + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" + assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" + assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension" + assert lse_partial.shape == out_partial.shape[:-1] + + # Determine if this is variable length based on dimensions + is_varlen = out_partial.dim() == 4 + + # Validate output tensor shapes and types + assert out.shape == out_partial.shape[1:], "out shape mismatch" + if lse is not None: + assert lse.shape == lse_partial.shape[1:], "lse shape mismatch" + assert lse.dtype == torch.float32, "lse must be fp32" + + # Validate optional tensors + for t, name in [ + (cu_seqlens, "cu_seqlens"), + (seqused, "seqused"), + (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), + ]: + if t is not None: + assert t.dtype == torch.int32, f"{name} must be int32" + assert t.is_cuda, f"{name} must be on CUDA device" + assert t.is_contiguous(), f"{name} must be contiguous" + + head_dim = out_partial.shape[-1] + num_splits = out_partial.shape[0] + assert num_splits <= 256 + # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + # so that kBlockM is smaller and we have more parallelism. + k_block_size = 64 if head_dim <= 64 else 128 + # We want kBlockM to be as small as possible to maximize parallelism. + # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). + m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) + log_max_splits = max(math.ceil(math.log2(num_splits)), 4) + if m_block_size == 8: + # If kBlockM == 8 then the minimum number of splits is 32. + # TODO: we can deal w this by using 128 threads instead + log_max_splits = max(log_max_splits, 5) + + # Convert to cute tensors (using kernel-formatted tensors) + out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=4 if not is_varlen else 3 ) + lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=lse_partial.ndim - 2 + ) + out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3 if not is_varlen else 2) + lse_tensor = ( + from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) + if lse is not None + else None + ) + + optional_tensors = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if t is not None + else None + for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) + ] + cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( + optional_tensors + ) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Create combine kernel configuration + dtype = torch2cute_dtype_map[out.dtype] + dtype_partial = torch2cute_dtype_map[out_partial.dtype] + + compile_key = ( + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + cu_seqlens is not None, + seqused is not None, + lse is not None, + ) + + if compile_key not in _flash_attn_fwd_combine.compile_cache: + fa_combine = FlashAttentionForwardCombine( + dtype=dtype, + dtype_partial=dtype_partial, + head_dim=head_dim, + m_block_size=m_block_size, + k_block_size=k_block_size, + log_max_splits=log_max_splits, + ) + + # Check if implementation is supported + if not fa_combine.can_implement( + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads=256, + ): + raise RuntimeError( + "FlashAttention combine kernel cannot be implemented with given parameters" + ) + + _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( + fa_combine, + out_partial_tensor, + lse_partial_tensor, + out_tensor, + lse_tensor, + cu_seqlens_tensor, + seqused_tensor, + num_splits_dynamic_tensor, + semaphore_tensor, + current_stream, + ) + + _flash_attn_fwd_combine.compile_cache[compile_key]( + out_partial_tensor, + lse_partial_tensor, + out_tensor, + lse_tensor, + cu_seqlens_tensor, + seqused_tensor, + num_splits_dynamic_tensor, + semaphore_tensor, + current_stream, + ) + + +_flash_attn_fwd_combine.compile_cache = {} + + +def flash_attn_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + return_lse: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Flash Attention combine function for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. This is the main user-facing + interface for the combine kernel. + + Args: + out_partial: Partial outputs tensor with shape: + - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input + - (num_splits, total_q, num_heads, head_size) for variable length input + lse_partial: Partial LSE tensor with shape: + - (num_splits, batch_size, seqlen, num_heads) for regular batched input + - (num_splits, total_q, num_heads) for variable length input + out: Optional output tensor. If None, will be created automatically. + out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. + cu_seqlens: Cumulative sequence lengths for variable length sequences + seqused: Used sequence lengths for each batch + return_lse: Whether to return the combined LSE tensor. Default is True. + + Returns: + Tuple of (out, lse) where: + - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size) + or (total_q, num_heads, head_size) for varlen + - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads) + or (total_q, num_heads) for varlen. None if return_lse=False + + Note: + This function expects the input tensors to be in the format produced by + split attention computation, where the first dimension is num_splits. + The permuting from user format to kernel format is now done inside the kernel. + """ + # Input validation + assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" + assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" + assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)" + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + + # Determine if this is variable length based on dimensions + is_varlen = out_partial.dim() == 4 + + if is_varlen: + # Variable length: (num_splits, total_q, num_heads, head_size) + num_splits, total_q, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, total_q, num_heads), ( + "lse_partial shape mismatch for varlen" + ) + batch_size = 1 # Treat as single batch for varlen + seqlen = total_q + else: + # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) + num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), ( + "lse_partial shape mismatch" + ) + + # Determine output dtype + if out_dtype is None: + out_dtype = out_partial.dtype + + # Create output if not provided + device = out_partial.device + if out is None: + if is_varlen: + out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) + else: + out = torch.empty( + batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device + ) + + # Create lse output only if requested + if return_lse: + if is_varlen: + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose( + 0, 1 + ) + else: + lse = torch.empty( + batch_size, num_heads, seqlen, dtype=torch.float32, device=device + ).transpose(1, 2) + else: + lse = None + + _flash_attn_fwd_combine( + out_partial, + lse_partial, + out, + lse, + cu_seqlens, + seqused, + ) + return out, lse diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index eb3770deea8..430c7d26fc5 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1,71 +1,496 @@ # Copyright (c) 2025, Tri Dao. +from typing import Optional, Callable +from dataclasses import dataclass + import cutlass import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr import flash_attn.cute.utils as utils -class AttentionMask: +@cute.jit +def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: + # Bit manipulation, compiles down to the R2P instruction + # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using. + # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., + # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + if const_expr(arch == 90): + col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2) + else: + col_limit_transformed = col_limit + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + # Don't need to clamp to 32 since the shr.u32 instruction does that already + col_limit_right_s = max(col_limit_transformed - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = (1 << col_limit_right_s) - 1 + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + # This is the equivalent of: + # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf + + +@cute.jit +def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None: + # Bit manipulation, compiles down to the R2P instruction + # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127 + # or 0, 1, ..., 15, 32, ..., 47, 64, ... + # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + # Here we hardcode for the case of 2 warp groups. + num_wg = 2 + row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min( + row_limit_top % (num_rep * num_wg), num_rep + ) + ncol = cute.size(X.shape) + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + row_limit_top_s = max(row_limit_top_transformed - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = (1 << row_limit_top_s) - 1 + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + out_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + X[c] = -Float32.inf if out_bound else X[c] + # tidx = cute.arch.thread_idx()[0] % 256 + # if tidx == 128: + # cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound) - def __init__( - self, - m_block_size: cutlass.Constexpr[int], - n_block_size: cutlass.Constexpr[int], - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # only pass in if we're doing PackGQA - ): - self.m_block_size = m_block_size - self.n_block_size = n_block_size - self.seqlen_q = seqlen_q - self.seqlen_k = seqlen_k - self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa + +@dataclass(frozen=True) +class AttentionMask: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + seqlen_q: Int32 + seqlen_k: Int32 + window_size_left: Optional[Int32] = None + window_size_right: Optional[Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA + swap_AB: cutlass.Constexpr[bool] = False @cute.jit def apply_mask( self, acc_S: cute.Tensor, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, m_block: cutlass.Int32, n_block: cutlass.Int32, thr_mma: cute.TiledMma, - mask_seqlen: cutlass.Constexpr, - mask_causal: cutlass.Constexpr, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ) -> None: - acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) - cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) - tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS)) + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) + acc_shape = (self.tile_m, self.tile_n) + cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS), transpose=self.swap_AB) # We use t0ScS as these indices are known at compile time. We then must subtract the # column limit by the thread column offset. - t0ScS_mn = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cS)) - thr_col_offset = tScS_mn[0][1] - seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size - thr_col_offset - if not mask_causal: - if mask_seqlen: - # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): - if t0ScS_mn[0, c][1] >= seqlenk_col_limit: - acc_S_mn[None, c].fill(-cutlass.Float32.inf) - else: # Causal - # If PackGQA, we split the work of compute divmod among threads in the same row - threads_per_row = thr_mma.tv_layout_C.shape[0][0] - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): - assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" - assert cute.size(acc_S_mn.shape[0]) <= threads_per_row - tidx = thr_mma.thr_idx - mma_m_idx = (m_block * self.m_block_size + tScS_mn[tidx % threads_per_row, 0][0]) // self.qhead_per_kvhead_packgqa - causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset - for r in range(cute.size(tScS_mn.shape[0])): - # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. - if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): - row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + t0ScS_mn = utils.make_acc_tensor_mn_view( + thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB + ) + ROW = 0 if const_expr(not self.swap_AB) else 1 + COL = 1 if const_expr(not self.swap_AB) else 0 + thr_col_offset = tScS_mn[0][COL] + # To handle edge cases of completely masked out rows where n_block_max = 0, + # we treat negative n_blocks as 0th n_block + # TODO: find more transparent solution + if n_block < 0: + n_block = 0 + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + if const_expr(not mask_causal and not mask_local and mask_mod is None): + if const_expr(mask_seqlen): + # The compiler now choses not to use R2P + r2p = const_expr(False and not self.swap_AB) + if const_expr(not r2p): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] + else: + mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) + + elif const_expr( + not mask_causal and not mask_local and mask_mod is not None + ): # FlexAttention mask mod + nrow = const_expr(cute.size(tScS_mn.shape[0])) + ncol = const_expr(cute.size(tScS_mn.shape[1])) + thr_col_offset = tScS_mn[0, 0][1] + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + wrap_aux_indices = const_expr( + has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) + ) + + for r in cutlass.range_constexpr(nrow): + global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + row_for_mod = global_row_idx + if const_expr(wrap_aux_indices): + _, row_for_mod = divmod(global_row_idx, fastdiv_mods[0]) + + for col in cutlass.range_constexpr(ncol): + col_idx_local = t0ScS_mn[0, col][1] + # Convert to absolute column index + global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n + col_for_mod = global_col_idx + if const_expr(wrap_aux_indices): + _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1]) + + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32) + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + q_idx_ssa, + kv_idx_ssa, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + if const_expr(mask_seqlen): + out_of_bounds = (global_row_idx >= self.seqlen_q) or ( + global_col_idx >= self.seqlen_k + ) + if out_of_bounds: + acc_S_mn[r, col] = -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + + else: # Causal or local + if const_expr(not self.swap_AB): + # If PackGQA, we split the work of compute divmod among threads in the same row + threads_per_row = thr_mma.tv_layout_C.shape[0][0] + mma_m_idx = None + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert not self.swap_AB, "swap_AB with PackGQA not supported yet" + assert cute.arch.WARP_SIZE % threads_per_row == 0, ( + "threads_per_row must divide WARP_SIZE" + ) + assert cute.size(acc_S_mn.shape[0]) <= threads_per_row + tidx = thr_mma.thr_idx + mma_m_idx = ( + m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0] + ) // self.qhead_per_kvhead_packgqa + causal_row_offset = ( + 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset + ) + if const_expr(mask_causal): + r2p = const_expr(not self.swap_AB) # R2P trick, see apply_mask_sm100 + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. + if const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + col_limit_right = row_idx + causal_row_offset + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + if const_expr(not r2p): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + acc_S_mn[r, c] = ( + -Float32.inf + if t0ScS_mn[0, c][1] >= col_limit_right + else acc_S_mn[r, c] + ) + else: + mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True) + else: # Local + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if const_expr(self.window_size_left is not None) + else None + ) + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + if const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + if const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + else: + col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + col_limit_left = ( + row_idx + local_row_offset_left + if const_expr(self.window_size_left is not None) + else 0 + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col_idx = t0ScS_mn[0, c][1] + # only consider the column index, so the row index sets to 0. + if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -Float32.inf + else: # swap_AB + assert self.qhead_per_kvhead_packgqa == 1 + thr_row_offset = tScS_mn[0][ROW] + causal_row_offset = ( + seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset + ) + if const_expr(mask_causal): + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col0 = t0ScS_mn[0, c][COL] + # If col0 is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + row_limit_top = ( + self.tile_m + if col0 >= seqlenk_col_limit and mask_seqlen + else col0 - causal_row_offset + ) + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = ( + -Float32.inf + if t0ScS_mn[r, 0][ROW] < row_limit_top + else acc_S_mn[r, c] + ) else: - row_idx = utils.shuffle_sync(mma_m_idx, r % threads_per_row, width=threads_per_row) + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col0 = t0ScS_mn[0, c][COL] + # If col0 is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + row_limit_top = ( + self.tile_m + if col0 >= seqlenk_col_limit + else col0 - causal_row_offset - self.window_size_right + ) + # TODO: do we need col_limit_sink? + row_limit_bot = col0 - causal_row_offset + self.window_size_left + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + row_idx = t0ScS_mn[r, 0][ROW] + acc_S_mn[r, c] = ( + -Float32.inf + if row_idx < row_limit_top or row_idx > row_limit_bot + else acc_S_mn[r, c] + ) + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + m_block: Int32, + n_block: Int32, + thr_mma: cute.TiledMma, + thr_tmem_load: cute.TiledCopy, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + batch_idx: Int32 = None, + head_idx: Int32 = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + ) -> None: + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + acc_shape = (self.tile_m, self.tile_n) + cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) + tScS = thr_mma.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + # To handle edge cases of completely masked out rows where n_block_max = 0, + # we treat negative n_blocks as 0th n_block + # TODO: find more transparent solution + if n_block < 0: + n_block = 0 + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n + r2p = True + if const_expr(not mask_causal and not mask_local and mask_mod is None): + if const_expr(mask_seqlen): + if const_expr(not r2p): + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): + # if tScS_t2r[i][1] >= seqlenk_col_limit: + # acc_S[i] = -Float32.inf + # For some reason the 2 lines above generate really bad SASS + acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] + else: + mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) + + elif const_expr(not mask_causal and not mask_local and mask_mod is not None): + # Block sparse case w/ mask_mod + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + wrap_aux_indices = const_expr( + has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + row_coord_first = tScS_t2r[0][0] + global_row = row_coord_first + m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa != 1): + mask_row = global_row // self.qhead_per_kvhead_packgqa + else: + mask_row = global_row + mask_row_for_mod = mask_row + if const_expr(wrap_aux_indices): + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) + + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_col = col_coord + n_block * self.tile_n + global_col_for_mod = global_col + if const_expr(wrap_aux_indices): + _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) + kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + mask_row_ssa, + kv_idx_ssa, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if cond else -Float32.inf + if const_expr(mask_seqlen): + out_of_bounds = (global_row >= self.seqlen_q) or (global_col >= self.seqlen_k) + acc_S[i] = -Float32.inf if out_of_bounds else acc_S[i] + + else: # Causal or local + causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q + row_idx = tScS_t2r[0][0] + m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa != 1): + row_idx = row_idx // self.qhead_per_kvhead_packgqa + if const_expr(mask_causal): col_limit_right = row_idx + causal_row_offset - if cutlass.const_expr(mask_seqlen): + if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) - # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): - # only consider the column index, so the row index sets to 0. - if t0ScS_mn[0, c][1] >= col_limit_right: - acc_S_mn[r, c] = -cutlass.Float32.inf + # if cute.arch.thread_idx()[0] % 32 == 0: + # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) + ncol = const_expr(cute.size(tScS_t2r.shape)) + if const_expr(not r2p): + for i in cutlass.range(ncol, unroll_full=True): + acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + else: + mask_r2p(acc_S, col_limit_right, arch=100, rank1=True) + else: + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if const_expr(self.window_size_left is not None) + else None + ) + if const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + else: + col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + col_limit_left = ( + row_idx + local_row_offset_left + if const_expr(self.window_size_left is not None) + else 0 + ) + # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): + col_idx = tScS_t2r[i][1] + acc_S[i] = ( + -Float32.inf + if col_idx >= col_limit_right or col_idx < col_limit_left + else acc_S[i] + ) + + @cute.jit + def apply_mask_sm100_transposed( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + t0ScS_t2r: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + mask_local: cutlass.Constexpr, + ) -> None: + """ + Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. + """ + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + ROW = 0 if const_expr(not self.swap_AB) else 1 + COL = 1 if const_expr(not self.swap_AB) else 0 + assert t0ScS_t2r[0][COL] == 0, "col0 == 0" + thr_col_offset = tScS_t2r[0][COL] + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + if const_expr(not mask_causal and not mask_local): + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + else: # Causal or local + thr_row_offset = tScS_t2r[0][ROW] + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + causal_offset = seqlenq_row_limit - seqlenk_col_limit + if const_expr(mask_causal): + # tidx = cute.arch.thread_idx()[0] % 256 + # if tidx < 32: + # cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1]) + row_limit_top = causal_offset + if const_expr(mask_seqlen): + # If col is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + if seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + r2p = True + if const_expr(not r2p): + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = ( + -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i] + ) + else: + num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 + mask_r2p_transposed(acc_S, row_limit_top, num_rep) + else: + if const_expr(self.window_size_right is not None): + row_limit_top = causal_offset - self.window_size_right + else: + row_limit_top = 0 + if const_expr(self.window_size_left is not None): + row_limit_bot = causal_offset + self.window_size_left + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + row_idx = t0ScS_t2r[i][ROW] + local_mask = row_idx < row_limit_top + if const_expr(self.window_size_left is not None): + local_mask |= row_idx > row_limit_bot + acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py new file mode 100644 index 00000000000..546adf17f37 --- /dev/null +++ b/flash_attn/cute/mask_definitions.py @@ -0,0 +1,288 @@ +from typing import Callable, Optional + +import random +import math + +import cutlass +import cutlass.cute as cute +import torch + +from flash_attn.cute import utils + + +MaskModCallable = Optional[ + Callable[ + [ + "cute.TensorSSA", + "cute.TensorSSA", + "cute.TensorSSA", + "cute.TensorSSA", + "Optional[list]", + ], + "cute.TensorSSA", + ] +] + + +# Flex Attention mask functions (PyTorch signatures for reference implementation) +def get_flex_causal_mask(offset: int): + def _flex_causal_mask(b, h, q_idx, kv_idx): + return kv_idx <= q_idx + offset + + return _flex_causal_mask + + +def get_flex_block_causal_mask(offset: int): + def _flex_block_causal_mask(b, h, q_idx, kv_idx): + return kv_idx <= q_idx + offset + + return _flex_block_causal_mask + + +def get_flex_sliding_window_mask(window_left: int, window_right: int, offset: int): + def _flex_sliding_window_mask(b, h, q_idx, kv_idx): + center = q_idx + offset + lower = center - window_left + upper = center + window_right + return (kv_idx >= lower) & (kv_idx <= upper) + + return _flex_sliding_window_mask + + +def flex_block_diagonal_mask(b, h, q_idx, kv_idx): + block_size = 64 + return (q_idx // block_size) == (kv_idx // block_size) + + +def flex_mini_causal_mask(b, h, q_idx, kv_idx): + return (q_idx % 128) >= (kv_idx % 128) + + +def flex_document_mask(b, h, q_idx, kv_idx, doc_id): + return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] + + +# CuTe versions for kernel compilation +def get_cute_causal_mask(offset: int): + @cute.jit + def _cute_causal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors: None, + ) -> cute.TensorSSA: + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + return n_idx <= (m_idx + offset_ssa) + + return _cute_causal_mask + + +def get_cute_block_causal_mask(offset: int): + @cute.jit + def _cute_block_causal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors: None, + ) -> cute.TensorSSA: + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + return n_idx <= (m_idx + offset_ssa) + + return _cute_block_causal_mask + + +def get_cute_sliding_window_mask(window_left: int, window_right: int, offset: int): + @cute.jit + def _cute_sliding_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, + ) -> cute.TensorSSA: + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + window_left_ssa = utils.scalar_to_ssa(window_left, cutlass.Int32) + window_right_ssa = utils.scalar_to_ssa(window_right, cutlass.Int32) + center = m_idx + offset_ssa + lower = center - window_left_ssa + upper = center + window_right_ssa + return (n_idx >= lower) & (n_idx <= upper) + + return _cute_sliding_window_mask + + +@cute.jit +def cute_document_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors: list, +) -> cute.TensorSSA: + doc_id = aux_tensors[0] + m_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], m_idx[0]], cutlass.Int32) + n_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], n_idx[0]], cutlass.Int32) + return m_doc == n_doc + + +@cute.jit +def cute_block_diagonal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + block_size_ssa = utils.scalar_to_ssa(64, cutlass.Int32) + return (m_idx // block_size_ssa) == (n_idx // block_size_ssa) + + +@cute.jit +def cute_mini_causal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + tile_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32) + m_mod = m_idx % tile_size_ssa + n_mod = n_idx % tile_size_ssa + return m_mod >= n_mod + + +@cute.jit +def cute_prefix_lm_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size_ssa = utils.scalar_to_ssa(512, cutlass.Int32) + both_in_prefix = (m_idx < prefix_size_ssa) & (n_idx < prefix_size_ssa) + causal_part = m_idx >= n_idx + return both_in_prefix | causal_part + + +def flex_prefix_lm_mask(b, h, q_idx, kv_idx): + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size = 512 + both_in_prefix = (q_idx < prefix_size) & (kv_idx < prefix_size) + causal_part = q_idx >= kv_idx + return both_in_prefix | causal_part + + +@cute.jit +def cute_dilated_sliding_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + """Dilated sliding window: every other position in a 256-position window.""" + window_size_ssa = utils.scalar_to_ssa(256, cutlass.Int32) + dilation_ssa = utils.scalar_to_ssa(2, cutlass.Int32) + in_window = (m_idx >= n_idx) & (m_idx - n_idx < window_size_ssa) + dilated = ((m_idx - n_idx) % dilation_ssa) == utils.scalar_to_ssa(0, cutlass.Int32) + return in_window & dilated + + +def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx): + """Dilated sliding window: every other position in a 256-position window.""" + window_size = 256 + dilation = 2 + in_window = (q_idx >= kv_idx) & (q_idx - kv_idx < window_size) + dilated = ((q_idx - kv_idx) % dilation) == 0 + return in_window & dilated + + +def flex_ima_mask(b, h, q_idx, kv_idx, bias): + return kv_idx >= bias[kv_idx] + + +@cute.jit +def cute_ima_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + bias = aux_tensors[0] + threshold = utils.scalar_to_ssa(bias[n_idx[0]], cutlass.Int32) + return n_idx >= threshold + + +def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): + doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) + for b in range(batch): + for h in range(nheads): + N = seqlen_q + max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) + n = random.randint(1, max_segments) + n = min(n, N) + cuts = sorted(random.sample(range(1, N), n - 1)) + lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] + + doc_ids = [] + for i, length in enumerate(lengths): + doc_ids += [i for _ in range(length)] + + doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device) + return doc_ids_tensor + + +STATIC_MASKS = { + "block_diagonal": (cute_block_diagonal_mask, flex_block_diagonal_mask), + "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask), + "prefix_lm": (cute_prefix_lm_mask, flex_prefix_lm_mask), + "dilated_sliding_window": (cute_dilated_sliding_window_mask, flex_dilated_sliding_window_mask), + "document": (cute_document_mask, flex_document_mask), + "ima": (cute_ima_mask, flex_ima_mask), +} + +PARAMETERIZED_MASK_FACTORIES = { + "causal": (get_cute_causal_mask, get_flex_causal_mask), + "block_causal": (get_cute_block_causal_mask, get_flex_block_causal_mask), + "sliding_window": (get_cute_sliding_window_mask, get_flex_sliding_window_mask), +} + + +def get_mask_pair(mask_name, seqlen_q=None, seqlen_k=None, window_size=None): + """Get (cute_mask, flex_mask) pair for the given mask name. + + For static masks, seqlen info is not needed. + For parameterized masks, seqlen_q and seqlen_k are required. + """ + if mask_name in STATIC_MASKS: + return STATIC_MASKS[mask_name] + + if mask_name not in PARAMETERIZED_MASK_FACTORIES: + raise ValueError(f"Unknown mask: {mask_name}") + + if seqlen_q is None or seqlen_k is None: + raise ValueError(f"Parameterized mask '{mask_name}' requires seqlen_q and seqlen_k") + + cute_factory, flex_factory = PARAMETERIZED_MASK_FACTORIES[mask_name] + offset = seqlen_k - seqlen_q + + if mask_name == "sliding_window": + if window_size is None: + raise ValueError("sliding_window mask requires window_size parameter") + cute_mask = cute_factory(window_size, window_size, offset) + flex_mask = flex_factory(window_size, window_size, offset) + else: + cute_mask = cute_factory(offset) + flex_mask = flex_factory(offset) + + return cute_mask, flex_mask + + +if __name__ == "__main__": + doc_ids = random_doc_id_tensor(1, 2, 128) + print(f"{doc_ids = }") diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py new file mode 100644 index 00000000000..16336c34686 --- /dev/null +++ b/flash_attn/cute/mma_sm100_desc.py @@ -0,0 +1,291 @@ +# Copyright (c) 2025, Tri Dao. +# Ported Cutlass code from C++ to Python: +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix “layout” in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type → encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. + """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them + if cutlass_type is cutlass.FloatE4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.FloatE5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for Blackwell MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. + """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # fmt: off + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + # fmt: on + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. “INTERLEAVE” in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__ + # Swizzle string has the form "S" + swz_str = str(swizzle) + inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")] # '3,4,3' + B, M, S = [int(x) for x in inside.split(",")] # [3, 4, 3] + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. + """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index 99a76222bce..777c44079a0 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -10,3 +10,22 @@ class NamedBarrierFwd(enum.IntEnum): WarpSchedulerWG3 = enum.auto() PFull = enum.auto() PEmpty = enum.auto() + + +class NamedBarrierBwd(enum.IntEnum): + Epilogue = enum.auto() + WarpSchedulerWG1 = enum.auto() + WarpSchedulerWG2 = enum.auto() + WarpSchedulerWG3 = enum.auto() + PdS = enum.auto() + dQFullWG0 = enum.auto() + dQFullWG1 = enum.auto() + dQEmptyWG0 = enum.auto() + dQEmptyWG1 = enum.auto() + + +class NamedBarrierBwdSm100(enum.IntEnum): + EpilogueWG1 = enum.auto() + EpilogueWG2 = enum.auto() + Compute = enum.auto() + dQaccReduce = enum.auto() diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index a2dafa73c2f..765e71307ad 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -1,7 +1,5 @@ # Copyright (c) 2025, Tri Dao. -import math -import operator import cutlass import cutlass.cute as cute @@ -10,7 +8,6 @@ class PackGQA: - def __init__( self, m_block_size: cutlass.Constexpr[int], @@ -64,24 +61,27 @@ def load_Q( assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" num_threads = gmem_tiled_copy.size tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) - for m in range(cute.size(tQsQ.shape[1])): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): q_ptr_i64 = utils.shuffle_sync( tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row ) q_gmem_ptr = cute.make_ptr( mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) - if t0QcQ[0, m, 0][0] < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]: + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tQsQ.shape[0][0]) mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) - for k in range(cute.size(tQsQ.shape[2])): + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): ki = tQcQ[0, 0, k][1] // elems_per_load cute.copy( gmem_thr_copy, mQ_cur_copy[None, ki], tQsQ[None, m, k], - pred=tQpQ[None, m, k] if self.check_hdim_oob else None, + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs @@ -105,9 +105,11 @@ def store_LSE( assert cute.size(tLSErLSE) <= threads_per_row num_threads = tiled_mma.size tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) - for m in range(cute.size(tLSErLSE)): + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): lse_ptr_i64 = utils.shuffle_sync( - tPrLSEPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row, + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, ) lse_gmem_ptr = cute.make_ptr( mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 @@ -138,22 +140,25 @@ def store_O( assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" num_threads = gmem_tiled_copy.size tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) - for m in range(cute.size(tOrO.shape[1])): + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): o_ptr_i64 = utils.shuffle_sync( tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row ) o_gmem_ptr = cute.make_ptr( mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) - if t0OcO[0, m, 0][0] < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]: + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tOrO.shape[0][0]) mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) - for k in range(cute.size(tOrO.shape[2])): + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): ki = tOcO[0, 0, k][1] // elems_per_load cute.copy( gmem_thr_copy, tOrO[None, m, k], mO_cur_copy[None, ki], - pred=tOpO[None, m, k] if self.check_hdim_oob else None, + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, ) diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py new file mode 100644 index 00000000000..8b0949d1404 --- /dev/null +++ b/flash_attn/cute/paged_kv.py @@ -0,0 +1,188 @@ +from typing import Type +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync +from cutlass import Int32, const_expr + +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import ParamsBase +from cutlass.cute import FastDivmodDivisor + + +@dataclass +class PagedKVManager(ParamsBase): + mPageTable: cute.Tensor + mK_paged: cute.Tensor + mV_paged: cute.Tensor + thread_idx: Int32 + + page_size_divmod: FastDivmodDivisor + seqlen_k: Int32 + leftpad_k: Int32 + n_block_size: Int32 + num_threads: cutlass.Constexpr[Int32] + head_dim_padded: cutlass.Constexpr[Int32] + head_dim_v_padded: cutlass.Constexpr[Int32] + + gmem_threads_per_row: cutlass.Constexpr[Int32] + page_entry_per_thread: Int32 + async_copy_elems: Int32 + + gmem_tiled_copy_KV: cute.TiledCopy + gmem_thr_copy_KV: cute.TiledCopy + tPrPage: cute.Tensor + tPrPageOffset: cute.Tensor + tKpK: cute.Tensor + tVpV: cute.Tensor + + @staticmethod + def create( + mPageTable: cute.Tensor, + mK_paged: cute.Tensor, + mV_paged: cute.Tensor, + page_size_divmod: FastDivmodDivisor, + bidb: Int32, + bidh: Int32, + thread_idx: Int32, + seqlen_k: Int32, + leftpad_k: Int32, + n_block_size: cutlass.Constexpr[Int32], + head_dim_padded: cutlass.Constexpr[Int32], + head_dim_v_padded: cutlass.Constexpr[Int32], + num_threads: cutlass.Constexpr[Int32], + dtype: Type[cutlass.Numeric], + ): + universal_copy_bits = 128 + gmem_threads_per_row = 8 # 8 threads loading 128 bits = 128 bytes = 1 cache line + async_copy_elems = universal_copy_bits // dtype.width + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + dtype, + num_bits_per_copy=universal_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx) + page_entry_per_thread = n_block_size * gmem_threads_per_row // num_threads + + tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32) + tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32) + + mPageTable = mPageTable[bidb, None] + mK_paged = mK_paged[None, None, bidh, None] + mV_paged = mV_paged[None, None, bidh, None] + + cK = cute.make_identity_tensor((n_block_size, head_dim_padded)) + tKcK = gmem_thr_copy_KV.partition_S(cK) + tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1]) + + if const_expr(head_dim_padded == head_dim_v_padded): + tVpV = tKpK + else: + cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded)) + tVcV = gmem_thr_copy_KV.partition_S(cV) + tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0]) + + return PagedKVManager( + mPageTable, + mK_paged, + mV_paged, + thread_idx, + page_size_divmod, + seqlen_k, + leftpad_k, + n_block_size, + num_threads, + head_dim_padded, + head_dim_v_padded, + gmem_threads_per_row, + page_entry_per_thread, + async_copy_elems, + gmem_tiled_copy_KV, + gmem_thr_copy_KV, + tPrPage, + tPrPageOffset, + tKpK, + tVpV, + ) + + @cute.jit + def load_page_table(self, n_block: Int32): + for i in cutlass.range(self.page_entry_per_thread, unroll=1): + row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row + row_idx = n_block * self.n_block_size + row + + page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod) + + is_valid = ( + (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size + ) and row_idx < self.seqlen_k + page = self.mPageTable[page_idx] if is_valid else 0 + + self.tPrPage[i] = page + self.tPrPageOffset[i] = page_offset + + @cute.jit + def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): + assert K_or_V in ("K", "V") + + # Finesse sX layout to be (M, N). + sX_pi = cute.make_tensor( + sX.iterator, + cute.make_layout( + (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), + stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), + ), + ) + + if const_expr(K_or_V == "V"): + # Need to transpose V + sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + + head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded + cX = cute.make_identity_tensor((self.n_block_size, head_dim)) + tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi) + tXcX = self.gmem_thr_copy_KV.partition_S(cX) + + seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0 + for m in cutlass.range(cute.size(tXsX, mode=[1]), unroll=1): + should_load = tXcX[0, m, 0][0] < seqlenk_row_limit + + page = self.tPrPage[m] + page_offset = self.tPrPageOffset[m] + mX_paged_cur = ( + self.mK_paged[page_offset, None, page] + if const_expr(K_or_V == "K") + else self.mV_paged[None, page_offset, page] + ) + mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,)) + + if should_load: + for k in cutlass.range(cute.size(tXsX, mode=[2]), unroll=1): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + cute.copy( + self.gmem_tiled_copy_KV, + mX_paged_cur_copy[None, ki], + tXsX[None, m, k], + ) + elif const_expr(K_or_V == "V"): + # Don't need to clear out the rest of the smem for K since we'll mask out the scores anyway. + fill_swizzled(tXsX[None, m, None], 0) + + +@cutlass.dsl_user_op +def fill_swizzled(tensor, value: cutlass.Numeric, *, loc=None, ip=None) -> None: + """Fill tensor with a constant value. + + Fills all elements of the tensor with the specified value, assuming static size + and supported memory space. + """ + rTmp = cute.make_rmem_tensor_like(tensor, tensor.element_type) + rTmp.fill(value) + cute.autovec_copy(rTmp, tensor) diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 3df229c4f3e..7ed7ab06d29 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -6,10 +6,44 @@ import cutlass import cutlass.cute as cute -from cutlass.cutlass_dsl import Boolean, Int32, if_generate -from cutlass.utils import PipelineAsync, PipelineState, CooperativeGroup, pipeline_init_wait -from cutlass.utils.pipeline import PipelineUserType -from cutlass.utils.pipeline import _PipelineOp +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate +from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup +from cutlass.pipeline import PipelineUserType, PipelineOp +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg + + +# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed +def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): + """ + Fences the mbarrier init and syncs the threadblock or cluster + """ + cute.arch.mbarrier_init_fence() + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # If not using clusters, sync the threadblock + _sync(Agent.ThreadBlock) + else: + # If using clusters, sync the cluster + _sync(Agent.ThreadBlockCluster) + + +def _sync(group: Agent): + """ + Syncs all threads within an agent. + """ + if group is Agent.Thread: + raise NotImplementedError("Error: Not supported.") + elif group is Agent.ThreadBlock: + cute.arch.sync_threads() + elif group is Agent.ThreadBlockCluster: + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + assert False, ( + "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + ) class PipelineStateSimple: @@ -38,7 +72,10 @@ def stages(self) -> int: def index(self) -> Int32: # return self._phase_index & 0xFFFF # return self._phase_index & ((1 << self._log_stages) - 1) - return self._phase_index % self._stages + if const_expr(self._stages == 1): + return Int32(0) + else: + return self._phase_index % self._stages @property def phase(self) -> Int32: @@ -47,10 +84,16 @@ def phase(self) -> Int32: # take modulo 2. But in practice just passing the phase in without modulo works fine. # return (self._phase_index >> self._log_stages) % 2 # return self._phase_index >> self._log_stages - return self._phase_index // self._stages + if const_expr(self._stages == 1): + return self._phase_index + else: + return self._phase_index // self._stages def advance(self): - self._phase_index += 1 + if const_expr(self._stages == 1): + self._phase_index ^= 1 + else: + self._phase_index += 1 # def then_body(phase_index): # # XOR the phase bit and set the index to 0 @@ -67,9 +110,6 @@ def advance(self): # [Int32], # ) - def __get_mlir_types__(self): - return [self._phase_index.type] - def __extract_mlir_values__(self): phase_index = self._phase_index return [phase_index.ir_value()] @@ -88,34 +128,34 @@ def make_pipeline_state(type: PipelineUserType, stages: int): elif type is PipelineUserType.Consumer: return PipelineStateSimple(stages, Int32(0)) else: - assert ( - False - ), "Error: invalid PipelineUserType specified for make_pipeline_state." - + assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." @dataclass(frozen=True) -class PipelineTmaAsyncNoCluster(PipelineAsync): - +class PipelineTmaAsync(PipelineTmaAsyncOg): """ - If size(ClusterShape) == 1, PipelineTmaAsync has all threads - signaling the barrier during consumer_release. This causes a perf regression in FA3 - forward pass (especially hdim 128 causal). We instead implement a version of - PipelineTmaAsync where only 1 out of 128 threads signals the barrier. - - Assumptions: - (1) num_consumers % NumThreadsPerWarpGroup == 0 - (2) all 128 threads in the warp group are sync'ed right before calling consumer_release + If size(ClusterShape) == 1, PipelineTmaAsync has all threads + signaling the barrier during consumer_release. This causes a perf regression in FA3 + forward pass (especially hdim 128 causal). We instead implement a version of + PipelineTmaAsync where only 1 out of 128 threads signals the barrier. + + Assumptions: + (1) num_consumers % NumThreadsPerWarpGroup == 0 + (2) all 128 threads in the warp group are sync'ed right before calling consumer_release """ @staticmethod def create( - barrier_storage: cute.Pointer, - num_stages: Int32, + *, + num_stages: int, producer_group: CooperativeGroup, consumer_group: CooperativeGroup, tx_count: int, - init_wait: bool = True, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + tidx: Optional[Int32] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), + init_wait: cutlass.Constexpr[bool] = True, ): """ This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. @@ -123,59 +163,203 @@ def create( :type barrier_storage: cute.Pointer :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent + :param consumer_group: `CooperativeGroup` for the consumer agent :type consumer_group: CooperativeGroup :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param tidx: thread index to consumer async threads + :type tidx: Int32 | None + :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. + :type mcast_mode_mn: tuple[int, int] """ - producer_type = _PipelineOp.TmaLoad - consumer_type = _PipelineOp.AsyncThread + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.AsyncThread + producer = (producer_type, producer_group) consumer = (consumer_type, consumer_group) - sync_object_array_full = PipelineAsync._make_sync_object_array( + + sync_object_full = PipelineAsync._make_sync_object( barrier_storage.align(min_align=8), num_stages, producer, tx_count ) - sync_object_array_empty = PipelineAsync._make_sync_object_array( + sync_object_empty = PipelineAsync._make_sync_object( barrier_storage.align(min_align=8) + num_stages, num_stages, consumer ) - dst_rank = None + if tidx is None: + tidx, _, _ = cute.arch.thread_idx() + if cta_layout_vmnk is None: + cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + if const_expr(cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1): + dst_rank = None + is_signalling_thread = tidx % 128 == 0 + else: + ( + dst_rank, + is_signalling_thread, + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal( + cta_layout_vmnk, tidx, mcast_mode_mn + ) + producer_mask = None - if init_wait: + + if const_expr(init_wait): pipeline_init_wait() - return PipelineTmaAsyncNoCluster( - sync_object_array_full, - sync_object_array_empty, + + return PipelineTmaAsync( + sync_object_full, + sync_object_empty, num_stages, producer_mask, dst_rank, + is_signalling_thread, ) def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, ): """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. """ if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_array_empty.wait(state.index, state.phase), + lambda: self.sync_object_empty.wait(state.index, state.phase), ) - self.sync_object_array_full.arrive(state.index, self.producer_mask) - - def producer_commit(self, state: PipelineState): - """ - TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA. - """ - pass + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count) def consumer_release(self, state: PipelineState): """ TMA consumer release conditionally signals the empty buffer to the producer. """ # Only 1 thread per warp group signals the empty buffer. + if self.consumer_mask is None: # No cluster, 1 thread per warp group to signal + if_generate( + cute.arch.thread_idx()[0] % 128 == 0, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) + else: + if_generate( + self.is_signalling_thread, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineTmaUmmaOg): + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), + init_wait: cutlass.Constexpr[bool] = True, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. + :type mcast_mode_mn: tuple[int, int] + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # No mcast mask if not using clusters + producer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + else: + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask( + cta_layout_vmnk, mcast_mode_mn + ) + is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + consumer_mask = producer_mask + + if const_expr(init_wait): + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + cta_group, + ) + + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ if_generate( - cute.arch.thread_idx()[0] % 128 == 0, - lambda: self.sync_object_array_empty.arrive(state.index, self.consumer_mask), + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), ) + if const_expr(extra_tx_count == 0): + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count), + ) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml new file mode 100644 index 00000000000..8b5942b10d0 --- /dev/null +++ b/flash_attn/cute/pyproject.toml @@ -0,0 +1,53 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "flash-attn-cute" +version = "0.1.0" +description = "Flash Attention CUTE (CUDA Template Engine) implementation" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "BSD 3-Clause License"} +authors = [ + {name = "Tri Dao"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "nvidia-cutlass-dsl==4.3.0", + "torch", + "einops", + "typing_extensions", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "ruff", +] + +[project.urls] +Homepage = "https://github.com/Dao-AILab/flash-attention" +Repository = "https://github.com/Dao-AILab/flash-attention" + +[tool.setuptools] +packages = ["flash_attn.cute"] +package-dir = {"flash_attn.cute" = "."} + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +ignore = [ + "E731", # do not assign a lambda expression, use a def + "E741", # Do not use variables named 'I', 'O', or 'l' + "F841", # local variable is assigned to but never used +] diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index d14bfb827f9..0851ddd0522 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -1,13 +1,50 @@ from typing import Optional +from dataclasses import dataclass import cutlass import cutlass.cute as cute +from cutlass import Int32, const_expr +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. +""" + +@dataclass(frozen=True) class SeqlenInfo: + offset: cutlass.Int32 + seqlen: cutlass.Int32 - def __init__( - self, + @staticmethod + def create( + batch_idx: cutlass.Int32, + seqlen_static: cutlass.Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + ): + offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + if const_expr(seqused is not None): + seqlen = seqused[batch_idx] + elif const_expr(cu_seqlens is not None): + seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + seqlen = seqlen_static + return SeqlenInfo(offset, seqlen) + + +@dataclass(frozen=True) +class SeqlenInfoQK: + offset_q: cutlass.Int32 + offset_k: cutlass.Int32 + seqlen_q: cutlass.Int32 + seqlen_k: cutlass.Int32 + has_cu_seqlens_q: cutlass.Constexpr[bool] + has_cu_seqlens_k: cutlass.Constexpr[bool] + + @staticmethod + def create( batch_idx: cutlass.Int32, seqlen_q_static: cutlass.Int32, seqlen_k_static: cutlass.Int32, @@ -16,13 +53,47 @@ def __init__( mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, ): - self.offset_q = 0 if cutlass.const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] - self.offset_k = 0 if cutlass.const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] - if cutlass.const_expr(mSeqUsedQ is not None): - self.seqlen_q = mSeqUsedQ[batch_idx] + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + else: + seqlen_q = ( + seqlen_q_static + if const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - offset_q + ) + if const_expr(mSeqUsedK is not None): + seqlen_k = mSeqUsedK[batch_idx] else: - self.seqlen_q = seqlen_q_static if cutlass.const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx + 1] - self.offset_q - if cutlass.const_expr(mSeqUsedK is not None): - self.seqlen_k = mSeqUsedK[batch_idx] + seqlen_k = ( + seqlen_k_static + if const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - offset_k + ) + has_cu_seqlens_q: int = mCuSeqlensQ is not None + has_cu_seqlens_k: int = mCuSeqlensK is not None + return SeqlenInfoQK( + offset_q, offset_k, seqlen_q, seqlen_k, has_cu_seqlens_q, has_cu_seqlens_k + ) + + def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: + """Seqlen must be the first dimension of mQ""" + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset = ( + self.offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, self.offset_q) + ) + idx = (offset,) + (0,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) + + def offset_batch_K(self, mK: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: + """Seqlen must be the first dimension of mK""" + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] else: - self.seqlen_k = seqlen_k_static if cutlass.const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - self.offset_k + idx = (self.offset_k,) + (0,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index a658d072585..658934ce753 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -2,24 +2,51 @@ import math import operator +from typing import Tuple +from dataclasses import dataclass import cutlass import cutlass.cute as cute +from cutlass import Float32 import flash_attn.cute.utils as utils +from flash_attn.cute.cute_dsl_utils import ParamsBase -class Softmax: +@dataclass +class Softmax(ParamsBase): + scale_log2: Float32 + num_rows: cutlass.Constexpr[int] + row_max: cute.Tensor + row_sum: cute.Tensor + arch: cutlass.Constexpr[int] = 80 + softmax_scale: Float32 | None = None - def __init__(self, scale_log2: cutlass.Float32, num_rows: cutlass.Constexpr[int]): - self.scale_log2 = scale_log2 - self.row_max = cute.make_fragment(num_rows, cutlass.Float32) - self.row_sum = cute.make_fragment_like(self.row_max) + @staticmethod + def create( + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + softmax_scale: Float32 | None = None, + ): + row_max = cute.make_fragment(num_rows, Float32) + row_sum = cute.make_fragment(num_rows, Float32) + return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) def reset(self) -> None: - self.row_max.fill(-cutlass.Float32.inf) + self.row_max.fill(-Float32.inf) self.row_sum.fill(0.0) + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + @cute.jit def online_softmax( self, @@ -36,53 +63,82 @@ def online_softmax( """ # Change acc_S to M,N layout view. acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) - row_scale = cute.make_fragment_like(self.row_max, cutlass.Float32) + row_scale = cute.make_fragment_like(self.row_max, Float32) + + row_max = self.row_max + row_sum = self.row_sum + scale_log2 = self.scale_log2 + arch = self.arch + # Each iteration processes one row of acc_S - for r in range(cute.size(self.row_max)): + for r in cutlass.range(cute.size(row_max), unroll_full=True): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) - row_max_cur = acc_S_row.reduce(cute.ReductionOp.MAX, -cutlass.Float32.inf, 0) + + row_max_cur = utils.fmax_reduce( + acc_S_row, + init_val=row_max[r] if cutlass.const_expr(not is_first) else None, + arch=arch, + ) + row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + if cutlass.const_expr(is_first): - if check_inf: - row_max_cur = 0.0 if row_max_cur == -cutlass.Float32.inf else row_max_cur - row_max_cur_scaled = row_max_cur * self.scale_log2 - acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) - acc_S_row_sum = acc_S_row_exp.reduce(cute.ReductionOp.ADD, cutlass.Float32.zero, 0) + row_max_cur_scaled = row_max_cur * scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) + + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) row_scale[r] = 1.0 else: - row_max_prev = self.row_max[r] - row_max_cur = cute.arch.fmax(row_max_prev, row_max_cur) - if check_inf: - row_max_cur = 0.0 if row_max_cur == -cutlass.Float32.inf else row_max_cur - row_max_cur_scaled = row_max_cur * self.scale_log2 - acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) - acc_S_row_sum = acc_S_row_exp.reduce(cute.ReductionOp.ADD, cutlass.Float32.zero, 0) + row_max_prev = row_max[r] + row_max_cur_scaled = row_max_cur * scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) - row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * self.scale_log2) - acc_S_row_sum = acc_S_row_sum + self.row_sum[r] * row_scale[r] - self.row_max[r] = row_max_cur - self.row_sum[r] = acc_S_row_sum + row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2) + + acc_S_row_sum = utils.fadd_reduce( + acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch + ) + + row_max[r] = row_max_cur + row_sum[r] = acc_S_row_sum acc_S_mn[r, None].store(acc_S_row_exp) + return row_scale @cute.jit - def finalize(self, final_scale: cute.Float32 = 1.0) -> cute.Tensor: - """Finalize the online softmax by computing the scale and logsumexp. - """ + def finalize( + self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None + ) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp.""" + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) + row_sum = self.row_sum + row_max = self.row_max + scale_log2 = self.scale_log2 + # quad reduction for row_sum as we didn't do it during each iteration of online softmax - self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) - row_scale = cute.make_fragment_like(self.row_max, cutlass.Float32) - for r in range(cute.size(self.row_sum)): + row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) + row_scale = cute.make_fragment_like(row_max, Float32) + + for r in cutlass.range(cute.size(row_sum), unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2) + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 - acc_O_mn_row_is_zero_or_nan = self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] row_scale[r] = ( - cute.arch.rcp_approx(self.row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) ) * final_scale - row_sum_cur = self.row_sum[r] + row_sum_cur = row_sum[r] LN2 = math.log(2.0) - self.row_sum[r] = ( - (self.row_max[r] * self.scale_log2 + utils.log2f(row_sum_cur)) * LN2 - if not acc_O_mn_row_is_zero_or_nan else -cutlass.Float32.inf + row_sum[r] = ( + (row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf ) return row_scale @@ -96,5 +152,289 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: """ acc_O_mn = utils.make_acc_tensor_mn_view(acc_O) assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) - for r in range(cute.size(row_scale)): + for r in cutlass.range(cute.size(row_scale), unroll_full=True): acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +@dataclass +class SoftmaxSm100(Softmax): + rescale_threshold: cutlass.Constexpr[float] = 0.0 + + @staticmethod + def create( + scale_log2: Float32, + rescale_threshold: cutlass.Constexpr[float] = 0.0, + softmax_scale: Float32 | None = None, + ): + num_rows = 1 + arch = 100 + row_max = cute.make_fragment(num_rows, Float32) + row_sum = cute.make_fragment(num_rows, Float32) + return SoftmaxSm100( + scale_log2, + num_rows, + row_max, + row_sum, + arch, + softmax_scale, + rescale_threshold=rescale_threshold, + ) + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = utils.exp2f(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + # tmp = self._compute_row_sum(acc_S_row_exp) + # self.row_sum[0] = self.row_sum[0] * row_scale + tmp + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + e2e: cutlass.Constexpr[bool] = False, + e2e_freq: cutlass.Constexpr[int] = 16, + e2e_res: cutlass.Constexpr[int] = 4, + e2e_frg_limit: cutlass.Constexpr[int] = 1, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + if cutlass.const_expr(not e2e): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + if cutlass.const_expr( + k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit + ): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + # acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( + # (acc_S_row[i], acc_S_row[i + 1]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # acc_S_row[i] = cute.arch.exp2(acc_S_row[i]) + # acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1]) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + # utils.fma_packed_f32x2( + # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # ) + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + +@cute.jit +def floor_if_packed( + q_idx, + qhead_per_kvhead: cutlass.Constexpr[int], +) -> cute.Tensor: + """Convert q_idx to packed format for Pack-GQA.""" + if cutlass.const_expr(qhead_per_kvhead == 1): + return q_idx + return q_idx // qhead_per_kvhead + + +@cute.jit +def apply_score_mod_inner( + score_tensor, + index_tensor, + score_mod: cutlass.Constexpr, + batch_idx, + head_idx, + softmax_scale, + vec_size: cutlass.Constexpr, + qk_acc_dtype: cutlass.Constexpr, + aux_tensors, + fastdiv_mods, + constant_q_idx: cutlass.Constexpr, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, +): + """Shared implementation for applying score modification. + + Args: + score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100) + index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100) + score_mod: The score modification function to apply + batch_idx: Batch index + head_idx: Head index + softmax_scale: Scale to apply + vec_size: Vector size for processing elements + qk_acc_dtype: Data type for accumulator + aux_tensors: Optional aux_tensors for FlexAttention + fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping + constant_q_idx: If provided, use this constant for all q_idx values + If None, compute q_idx per-element + qhead_per_kvhead_packgqa: Pack-GQA replication factor. Divide q_idx by this + when greater than 1 so score mods see logical heads. + """ + n_vals = cutlass.const_expr(cute.size(score_tensor.shape)) + score_vec = cute.make_fragment(vec_size, qk_acc_dtype) + kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + + # SSA values for batch (constant across all elements) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) + + # Handle q_idx based on whether it's constant + q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + + # For Pack-GQA with non-constant q_idx, we need per-element head indices + # since a thread my process multiple query head indices + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + + for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): + for j in cutlass.range(vec_size, unroll_full=True): + score_vec[j] = score_tensor[i + j] * softmax_scale + + # Extract head offset from packed q_idx for Pack-GQA + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + q_idx_packed = index_tensor[i + j][0] + # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead) + q_idx_logical = q_idx_packed // qhead_per_kvhead + head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead + head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset + + # If we will do loads we mod, in order to not read OOB + if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None): + if cutlass.const_expr(constant_q_idx is None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + q_idx_floored = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) + _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod) + q_idx_vec[j] = q_idx_wrapped + else: + _, seqlen_k_divmod = fastdiv_mods + + _, kv_idx_wrapped = divmod(index_tensor[i + j][1], seqlen_k_divmod) + kv_idx_vec[j] = kv_idx_wrapped + else: + # No bounds checking - direct indexing + if constant_q_idx is None: + q_idx_vec[j] = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) + kv_idx_vec[j] = index_tensor[i + j][1] + + # Convert to SSA for score_mod call + score_ssa = score_vec.load() + kv_idx_ssa = kv_idx_vec.load() + if cutlass.const_expr(constant_q_idx is None): + q_idx_ssa = q_idx_vec.load() + else: + # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical + q_idx_const = constant_q_idx + q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,)) + + # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_ssa = head_idx_vec.load() + else: + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) + + aux_args = [] + if cutlass.const_expr(aux_tensors is not None): + aux_args = aux_tensors + + post_mod_scores = score_mod( + score_ssa, + batch_idx_ssa, + head_idx_ssa, + q_idx=q_idx_ssa, + kv_idx=kv_idx_ssa, + aux_tensors=aux_args, + ) + + # Write back modified scores + score_vec.store(post_mod_scores) + for j in cutlass.range(vec_size, unroll_full=True): + score_tensor[i + j] = score_vec[j] diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py new file mode 100644 index 00000000000..a23a624d059 --- /dev/null +++ b/flash_attn/cute/testing.py @@ -0,0 +1,418 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(input, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros( + first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype + ) + output[indices] = values + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + grad_values = grad_output[indices] + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) + else: + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + + if zero_lengths: + for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) + return padding_mask + + +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + qv=None, + kvpacked=False, + qkvpacked=False, + query_unused_mask=None, + key_unused_mask=None, +): + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + qv.detach() if qv is not None else None, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(None, None), + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] is None: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + if window_size[1] is None: + local_mask_left = col_idx > sk + else: + local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk) + return torch.logical_or( + local_mask_left, + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length + ), + ) + + +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(None, None), + attention_chunk=0, + sink_token_length=0, + learnable_sink: Optional[torch.Tensor] = None, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + local_mask = None + if window_size[0] is not None or window_size[1] is not None: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = ( + torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + ) + if local_mask is not None: + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( + learnable_sink - logits_or_sinks_max + ) + attention = (unnormalized_scores / normalizer).to(v.dtype) + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + if key_padding_mask is not None: + attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + if local_mask is not None: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py new file mode 100644 index 00000000000..ef47cedecdf --- /dev/null +++ b/flash_attn/cute/tile_scheduler.py @@ -0,0 +1,715 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Optional, Tuple +from dataclasses import dataclass, fields + +try: + from typing import override +except ImportError: # Python < 3.12 + from typing_extensions import override + +import cutlass +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass import Int32, const_expr + +import flash_attn.cute.utils as utils +from flash_attn.cute.fast_math import clz +from cutlass.cute import FastDivmodDivisor + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 5 + new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params( + args.num_block, + args.num_head, + args.num_batch, + args.num_splits, + FastDivmodDivisor(args.num_splits), + args.is_split_kv, + args.cluster_shape_mn, + ) + + def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): + self.params = params + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": + blk_coord = cute.arch.block_idx() + return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) + assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + return ( + cute.round_up(params.num_block, params.cluster_shape_mn[0]), + params.num_head * params.num_splits, + params.num_batch, + ) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + block_idx, head_idx, batch_idx = self._blk_coord + if const_expr(self.params.is_split_kv): + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) + else: + split_idx = Int32(0) + return WorkTileInfo( + (block_idx, head_idx, batch_idx, split_idx), + self._is_first_block, + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + total_blocks = args.num_block * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": + tile_idx = cute.arch.block_idx()[0] + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1)) + + # @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self._tile_idx < self.params.total_blocks + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._tile_idx += cute.arch.grid_dim()[0] + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_splits: Int32 + num_block: Int32 + l2_minor: Int32 + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + is_split_kv: cutlass.Constexpr[bool] = False + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler.Params": + # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # Seems faster if swizzle if a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block=args.num_block, + l2_minor=Int32(swizzle), + num_block_divmod=FastDivmodDivisor(args.num_block), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + num_splits=args.num_splits, + is_split_kv=args.is_split_kv, + ) + + def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, params.num_splits, Int32(1)) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + # Longest-processing-time-first + block = params.num_block - 1 - block + is_valid = self._tile_idx < params.total_blocks + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx, self._split_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTBwdScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_block: Int32 + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + spt: cutlass.Constexpr[bool] = True + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTBwdScheduler.Params": + size_l2 = 50 * 1024 * 1024 + size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 + size_one_dqaccum_head = 0 + size_one_head = size_one_qdo_head + size_one_dqaccum_head + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # swizzle = 8 + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0]) + return SingleTileLPTBwdScheduler.Params( + total_blocks=(num_block * args.cluster_shape_mn[0]) + * args.num_head + * args.num_batch, + num_block=num_block, + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * num_block), + l2_minor_residual_divmod=FastDivmodDivisor( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + cluster_shape_mn=args.cluster_shape_mn, + spt=args.lpt, + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, Int32(1), Int32(1)) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0] + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + is_valid = self._tile_idx < params.total_blocks + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] + if cutlass.const_expr(params.spt): + block = params.num_block - 1 - block + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler.Params": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + max_kvblock_in_l2 = size_l2 // ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + is_split_kv=args.is_split_kv, + ) + + def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(seqlen, params.tile_shape_mn[0]) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = False + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < params.num_batch + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx, self._split_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx, self._split_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 3768fa3a9a1..f73f66cfccf 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -1,16 +1,69 @@ # Copyright (c) 2025, Tri Dao. import math -from typing import Type, Callable, Optional +import hashlib +import inspect +import re +from typing import Type, Callable, Optional, Tuple, overload +from functools import partial import cutlass import cutlass.cute as cute +from cutlass import Float32, const_expr from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm from cutlass.cute.runtime import from_dlpack +# cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default +fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, + rnd=nvvm.RoundingModeKind.RN, +) + + +def hash_callable(func: Callable) -> str: + """Hash a callable based on the source code or bytecode and closure values.""" + if hasattr(func, "__wrapped__"): + # cute.jit returns a wrapper whose repr/closure changes per compile; hash the undecorated function. + base_func = func.__wrapped__ + func = base_func + + try: + data = inspect.getsource(func).encode() + except (OSError, TypeError): + if hasattr(func, "__code__") and func.__code__ is not None: + data = func.__code__.co_code + else: + data = repr(func).encode() + + hasher = hashlib.sha256(data) + + if hasattr(func, "__closure__") and func.__closure__ is not None: + for idx, cell in enumerate(func.__closure__): + cell_value = cell.cell_contents + hasher.update(repr(cell_value).encode()) + + return hasher.hexdigest() + + +def create_softcap_scoremod(softcap_val): + inv_softcap = 1.0 / softcap_val + + @cute.jit + def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, buffers): + scores = acc_S_SSA * inv_softcap + return scores * cute.math.tanh(scores, fastmath=True) + + return scoremod_premask_fn + + def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: return ( from_dlpack(x, assumed_align=alignment) @@ -21,44 +74,40 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te ) +def convert_from_dlpack_leading_static( + x, leading_dim, alignment=16, static_modes=None, stride_order=None +) -> cute.Tensor: + if stride_order is None: + stride_order = x.dim_order() + x_ = from_dlpack(x, assumed_align=alignment) + for i in range(x.ndim): + if i != leading_dim and (static_modes is None or i not in static_modes): + x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order) + return x_ + + def make_tiled_copy_A( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: - if swapAB: - return make_tiled_copy_B(copy_atom, tiled_mma) + if const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) else: - return cute.make_tiled_copy( - copy_atom, - layout_tv=tiled_mma.tv_layout_A_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), - ) + return cute.make_tiled_copy_A(copy_atom, tiled_mma) def make_tiled_copy_B( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: - if swapAB: - return make_tiled_copy_A(copy_atom, tiled_mma) + if const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) else: - return cute.make_tiled_copy( - copy_atom, - layout_tv=tiled_mma.tv_layout_B_tiled, - tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - ) - - -def make_tiled_copy_C(copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma) -> cute.TiledCopy: - return cute.make_tiled_copy( - copy_atom, - layout_tv=tiled_mma.tv_layout_C_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), - ) + return cute.make_tiled_copy_B(copy_atom, tiled_mma) def mma_make_fragment_A( smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.Tensor: - if swapAB: + if const_expr(swapAB): return mma_make_fragment_B(smem, thr_mma) else: return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) @@ -67,112 +116,174 @@ def mma_make_fragment_A( def mma_make_fragment_B( smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.Tensor: - if swapAB: + if const_expr(swapAB): return mma_make_fragment_A(smem, thr_mma) else: return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) -def get_smem_store_atom(arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric]) -> cute.CopyAtom: - if arch < 90: +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): return cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), element_type, num_bits_per_copy=2 * element_type.width, + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, ) else: return cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), element_type, + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, ) - -def max_constexpr( - a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric] -) -> cutlass.Constexpr[cute.Numeric]: - return a if a > b else b - - +@cute.jit def warp_reduce( val: cute.TensorSSA | cute.Numeric, op: Callable, - width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, ) -> cute.TensorSSA | cute.Numeric: - if isinstance(val, cute.TensorSSA): + if const_expr(isinstance(val, cute.TensorSSA)): res = cute.make_fragment(val.shape, val.dtype) res.store(val) - for i in range(cute.size(val.shape)): + for i in cutlass.range_constexpr(cute.size(val.shape)): res[i] = warp_reduce(res[i], op, width) return res.load() else: - for i in range(int(math.log2(width))): + for i in cutlass.range_constexpr(int(math.log2(width))): val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) return val -def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: """ For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). """ acc_layout_col_major = cute.make_layout(acc_layout.shape) - acc_layout_mn = cute.make_layout( + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M ( - (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M - (acc_layout_col_major.shape[0][0], *acc_layout_col_major.shape[0][2:], acc_layout_col_major.shape[2]), # MMA_N - *acc_layout_col_major.shape[3:], - ), - stride=( - (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M - (acc_layout_col_major.stride[0][0], *acc_layout_col_major.stride[0][2:], acc_layout_col_major.stride[2]), # MMA_N - *acc_layout_col_major.stride[3:], - ), + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) return cute.composition(acc_layout, acc_layout_mn) -def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor: - return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) +@cute.jit def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. - # Due to the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) - acc_layout_divided = cute.logical_divide(acc_layout, (None, None, 2)) - rA_mma_view = cute.make_layout( - ( - (acc_layout_divided.shape[0], acc_layout_divided.shape[2][0]), - acc_layout_divided.shape[1], - acc_layout_divided.shape[2][1], - ), - stride=( - (acc_layout_divided.stride[0], acc_layout_divided.stride[2][0]), - acc_layout_divided.stride[1], - acc_layout_divided.stride[2][1], - ), - ) + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) return rA_mma_view +def make_acc_tensor_frgA_view(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + def transpose_view(a: cute.Tensor) -> cute.Tensor: - """Transpose the first two dimensions of a tensor on smem. - """ + """Transpose the first two dimensions of a tensor on smem.""" shape = (a.shape[1], a.shape[0], *a.shape[2:]) order = (1, 0, *range(2, cute.rank(a))) return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + # stride = (a.layout.stride[1], a.layout.stride[0], *a.layout.stride[2:]) + # return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) -def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float32: - """exp2f calculation for both vector and scalar. +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A cute.Swizzle object constructed from the extracted parameters + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S" + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return cute.make_swizzle(b, m, s) + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + +@cute.jit +def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: + """exp2f calculation for both vector and scalar. :param x: input value - :type x: cute.TensorSSA or cutlass.Float32 + :type x: cute.TensorSSA or Float32 :return: exp2 value - :rtype: cute.TensorSSA or cutlass.Float32 + :rtype: cute.TensorSSA or Float32 """ - if isinstance(x, cute.TensorSSA): - res = cute.make_fragment(x.shape, cutlass.Float32) + if const_expr(isinstance(x, cute.TensorSSA)): + res = cute.make_fragment(x.shape, Float32) res.store(x) - for i in range(cute.size(x.shape)): + for i in cutlass.range_constexpr(cute.size(x.shape)): res[i] = cute.arch.exp2(res[i]) return res.load() else: @@ -180,11 +291,11 @@ def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float @dsl_user_op -def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32: - return cutlass.Float32( +def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( llvm.inline_asm( T.f32(), - [cutlass.Float32(a).ir_value(loc=loc, ip=ip)], + [Float32(a).ir_value(loc=loc, ip=ip)], "lg2.approx.ftz.f32 $0, $1;", "=f,f", has_side_effects=False, @@ -195,15 +306,126 @@ def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32: @dsl_user_op -def atomic_add_fp32( - a: float | cutlass.Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None -) -> None: +def logf(a: float | Float32, *, loc=None, ip=None) -> Float32: + return log2f(a, loc=loc, ip=ip) * math.log(2.0) + + +@dsl_user_op +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + # if const_expr(init_val is None): + # init_val = -cutlass.Float32.if + # return x.reduce(cute.ReductionOp.MAX, init_val, 0) + res = cute.make_fragment(x.shape, Float32) + res.store(x) + # local_max = [res[0], res[1]] + # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2): + # local_max[0] = fmax(local_max[0], res[i + 0]) + # local_max[1] = fmax(local_max[1], res[i + 1]) + # local_max[0] = fmax(local_max[0], local_max[1]) + # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + else: + # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max + # We instead force the 3-input max. + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) + local_max = [ + local_max_0, + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + # res = cute.make_fragment(x.shape, Float32) + # res.store(x) + # local_sum = [res[0], res[1], res[2], res[3]] + # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + # local_sum[0] += res[i + 0] + # local_sum[1] += res[i + 1] + # local_sum[2] += res[i + 2] + # local_sum[3] += res[i + 3] + # local_sum[0] += local_sum[1] + # local_sum[2] += local_sum[3] + # local_sum[0] += local_sum[2] + # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val + else: + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_sum_0 = ( + add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + # add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1])) + if const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@dsl_user_op +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() # # cache_hint = cutlass.Int64(0x12F0000000000000) # llvm.inline_asm( # None, - # [gmem_ptr_i64, cutlass.Float32(a).ir_value(loc=loc, ip=ip)], - # # [gmem_ptr_i64, cutlass.Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)], + # # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], # "red.global.add.f32 [$0], $1;", # # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", # # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", @@ -214,10 +436,7 @@ def atomic_add_fp32( # asm_dialect=llvm.AsmDialect.AD_ATT, # ) nvvm.atomicrmw( - res=T.f32(), - op=nvvm.AtomicOpKind.FADD, - ptr=gmem_ptr.llvm_ptr, - a=cutlass.Float32(a).ir_value() + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() ) @@ -226,82 +445,54 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) +@dsl_user_op +def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(x.stride) + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + # HACK: we assume that applying the offset does not change the pointer alignment + byte_offset = offset * x.element_type.width // 8 + return cute.make_ptr( + x.element_type, + x.iterator.toint() + byte_offset, + x.memspace, + assumed_align=x.iterator.alignment, + ) + + +@cute.jit def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" tApA = cute.make_fragment( cute.make_layout( - (tAcA.shape[0][1], cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), stride=(cute.size(tAcA, mode=[2]), 0, 1), ), cutlass.Boolean, ) - for rest_v in range(tApA.shape[0]): - for rest_k in range(tApA.shape[2]): + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) return tApA -@dsl_user_op -def barrier_sync(barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, - *, loc=None, ip=None) -> None: - llvm.inline_asm( - None, - [cutlass.Int32(barrier_id).ir_value(loc=loc, ip=ip), cutlass.Int32(number_of_threads).ir_value(loc=loc, ip=ip)], - "bar.sync $0, $1;", - "r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -@dsl_user_op -def barrier_arrive(barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, *, loc=None, ip=None) -> None: - """ - Arrive at a named barrier. - """ - barrier_id = cutlass.Int32(barrier_id).ir_value(loc=loc, ip=ip) - number_of_threads = cutlass.Int32(number_of_threads).ir_value(loc=loc, ip=ip) - nvvm.barrier_arrive( - barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip - ) - # llvm.inline_asm( - # None, - # [barrier_id, number_of_threads], - # "bar.arrive $0, $1;", - # "r,r", - # has_side_effects=True, - # is_align_stack=False, - # asm_dialect=llvm.AsmDialect.AD_ATT, - # ) - - -@dsl_user_op -def cp_async_mbarrier_arrive_shared( - mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None -) -> None: - nvvm.cp_async_mbarrier_arrive_shared( - mbar_ptr.llvm_ptr, - noinc=noinc, - loc=loc, - ip=ip, - ) - - def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: warp_group_idx = cute.arch.thread_idx()[0] // 128 - if cutlass.const_expr(sync): + if const_expr(sync): warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) return warp_group_idx # @dsl_user_op -# def warp_vote_any_lt(a: float | cutlass.Float32, b: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Boolean: +# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean: # mask = cutlass.Int32(-1) # return cutlass.Boolean( # llvm.inline_asm( # T.i32(), -# [cutlass.Float32(a).ir_value(loc=loc, ip=ip), cutlass.Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], +# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], # ".pred p1, p2;\n" # "setp.lt.f32 p1, $1, $2;\n" # "vote.sync.any.pred p2, p1, $3;\n" @@ -315,14 +506,11 @@ def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: # ) -@dsl_user_op +@cute.jit def shuffle_sync( value: cute.Numeric, offset: cute.typing.Int, width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, - *, - loc=None, - ip=None ) -> cute.Numeric: assert value.width % 32 == 0, "value type must be a multiple of 32 bits" # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 @@ -332,6 +520,322 @@ def shuffle_sync( val = cute.make_fragment(1, type(value)) val[0] = value val_i32 = cute.recast_tensor(val, cutlass.Int32) - for i in range(cute.size(val_i32)): + for i in cutlass.range_constexpr(cute.size(val_i32)): val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) return val[0] + + +@dsl_user_op +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shr.s32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val) + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val) + return val + + +@dsl_user_op +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@overload +def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + + +@overload +def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst_or_dtype): + """Convert Float32 tensor to Float16/BFloat16. + + Args: + src: Source tensor with Float32 element type + dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16) + + Returns: + None if dst is a tensor, or a new tensor if dtype is provided + """ + if const_expr(isinstance(dst_or_dtype, type)): + # dtype variant: create new tensor and call the tensor variant + dtype = dst_or_dtype + dst = cute.make_fragment(src.shape, dtype) + cvt_f16(src, dst) + return dst + else: + # tensor variant: write to dst + dst = dst_or_dtype + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: + deg = len(poly) - 1 + out = poly[deg] + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = out * x + poly[i] + return out + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32: + # There's probably a way to call llvm or nvvm to do this instead of ptx + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], + "add.rm.ftz.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" + "mov.b32 x_rounded_i, $1;\n\t" + "mov.b32 frac_ex_i, $2;\n\t" + "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t" + # add.u32 generates IMAD instruction and add.s32 generates LEA instruction + # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik + "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t" + "mov.b32 $0, out_i;\n\t" + "}\n", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: + # We assume x <= 127.0 + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) + fp32_round_int = float(2**23 + 2**22) + x_clamped = cute.arch.fmax(x, -127.0) + # We want to round down here, so that the fractional part is in [0, 1) + x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) + # The integer floor of x is now in the last 8 bits of x_rounded + # We assume the next 2 ops round to nearest even. The rounding mode is important. + x_rounded_back = x_rounded - fp32_round_int + x_frac = x_clamped - x_rounded_back + x_frac_ex2 = evaluate_polynomial(x_frac, poly_ex2_deg3, loc=loc, ip=ip) + return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) + + +# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version +@dsl_user_op +def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + # We assume x <= 127.0 and y <= 127.0 + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + # We want to round down here, so that the fractional part is in [0, 1) + xy_rounded = cute.arch.add_packed_f32x2( + xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM + ) + # The integer floor of x & y are now in the last 8 bits of xy_rounded + # We want the next 2 ops to round to nearest even. The rounding mode is important. + xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int)) + xy_frac = sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, poly_ex2_deg3, loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def coord_offset_i64( + tensor: cute.Tensor, idx: cute.typing.Int, dim: int, *, loc=None, ip=None +) -> cute.Tensor: + offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim]) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + new_layout = cute.slice_( + tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1)) + ) + return cute.make_tensor(new_ptr, new_layout) + + +@cute.jit +def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" + vec = cute.make_fragment(1, dtype) + vec[0] = a + return vec.load() + + +def ssa_to_scalar(val): + """Could inline but nice for reflecting the above api""" + return val[0] diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1e041e4538d..865f1db5432 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -127,7 +127,10 @@ def _flash_attn_forward_fake( softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) if return_softmax: - p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) + if torch.cuda.is_available() and torch.version.hip: + p = torch.empty((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout) + else: + p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) return out, softmax_lse, p, rng_state @@ -220,10 +223,11 @@ def _flash_attn_varlen_forward_fake( out = torch.empty_like(q) softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = round_multiple(max_seqlen_k, 128) if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + if torch.cuda.is_available() and torch.version.hip: + p = torch.empty((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout) + else: + p = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) return out, softmax_lse, p, rng_state @@ -315,7 +319,10 @@ def _flash_attn_backward_fake( if dv is None: dv = torch.empty_like(v) batch_size, seqlen_q, num_heads, _ = q.shape - softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) + if torch.cuda.is_available() and torch.version.hip: + softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32) + else: + softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) return softmax_d @@ -426,7 +433,10 @@ def _flash_attn_varlen_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + if torch.cuda.is_available() and torch.version.hip: + softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32) + else: + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) return softmax_d @@ -1576,7 +1586,7 @@ def flash_attn_with_kvcache( softmax_scale = q.shape[-1] ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) cache_batch_idx = maybe_contiguous(cache_batch_idx) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py index c1e2ff5985f..5cc93edc5e4 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -1161,7 +1161,7 @@ def attention_prefill_backward_triton_split_impl( delta = torch.zeros_like(softmax_lse) if IS_VARLEN: stride_deltab = 0 - stride_deltam, stride_deltah = delta.stride() + stride_deltah, stride_deltam = delta.stride() else: stride_deltab, stride_deltah, stride_deltam = delta.stride() pre_grid = (triton.cdiv(max_seqlen_q_final, PRE_BLOCK), batch, nheads_q) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 3f2d92c22d6..5c16cf4c552 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_rdna def get_cdna_autotune_configs(): return [ @@ -23,6 +23,26 @@ def get_cdna_autotune_configs(): num_warps=4), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] +def get_rdna_autotune_configs(): + return [ + # Most aggressive - 128x128 (best for large sequences) + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # Large blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # Medium blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + def get_autotune_configs(): if AUTOTUNE: if is_cdna(): @@ -30,6 +50,11 @@ def get_autotune_configs(): fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + elif is_rdna(): + autotune_configs, autotune_keys = get_rdna_autotune_configs() + fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) else: raise ValueError("Unknown Device Type") else: diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index dec5673e3e5..b0b320321b6 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -186,24 +186,22 @@ def get_cdna_autotune_configs(): def get_rdna_autotune_configs(): return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), + # Best config from autotune on gfx1100: 32x16, warps=2, PRE_LOAD_V=True + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + # === Configs for head_dim=128 === + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + # === Fallback configs === + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + def get_autotune_configs(): if AUTOTUNE: if is_rdna(): @@ -214,8 +212,9 @@ def get_autotune_configs(): raise ValueError("Unknown Device Type") else: return [ + # Optimized for gfx1100 (RDNA3) with LLC-aware head grouping triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": True}, num_stages=1, num_warps=4, ), @@ -621,8 +620,9 @@ def attention_prefill_forward_triton_impl( # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) if is_varlen: - softmax_lse = torch.zeros((q.shape[0], nheads_q), device=q.device, dtype=torch.float32) - stride_lse_m, stride_lse_h = softmax_lse.stride() + total_seqlen_q, _, _ = q.shape + softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) + stride_lse_h, stride_lse_m = softmax_lse.stride() stride_lse_z = 0 else: softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index bb6e25b509c..c223ee93b6c 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -9,11 +9,17 @@ from .fwd_ref import attention_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 +from .l2_cache_aware import is_head_grouping_beneficial, print_head_grouping_info from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb from typing import Literal, Optional, Union -def fwd(q: torch.Tensor, +# Environment variable to enable verbose head grouping output +L2_HEAD_GROUPING_DEBUG = os.environ.get('FLASH_ATTN_HEAD_GROUPING_DEBUG', '0') == '1' + + + +def _fwd_single_group(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: Optional[torch.Tensor], @@ -31,6 +37,7 @@ def fwd(q: torch.Tensor, descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original fwd implementation for a single head group.""" if DEBUG: print() @@ -74,11 +81,9 @@ def fwd(q: torch.Tensor, if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - if dropout_p > 0.0: - metadata.need_dropout(dropout_p) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - else: - rng_state = None + # store rng state + metadata.need_dropout(dropout_p, return_softmax) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast # check arguments metadata.check_args(q, k, v, out) @@ -147,6 +152,112 @@ def fwd(q: torch.Tensor, return out, softmax_lse, sd_mask, rng_state + +def fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Flash attention forward with L2 cache-aware head grouping. + + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. + + Layout: bshd (batch, seqlen, heads, head_dim) + """ + # Get shapes for head grouping decision + # Layout is bshd: [batch, seqlen, heads, head_dim] + batch, seqlen_q, nheads_q, head_dim = q.shape + seqlen_k = k.shape[1] + nheads_k = k.shape[2] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0 + ) + + if L2_HEAD_GROUPING_DEBUG: + print_head_grouping_info(nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0) + + if not should_group or group_size >= nheads_q: + # No grouping needed - use original implementation + return _fwd_single_group( + q, k, v, out, alibi_slopes, dropout_p, softmax_scale, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q, descale_k, descale_v, descale_o + ) + + # Process heads in groups for L2 cache efficiency + if L2_HEAD_GROUPING_DEBUG: + print(f"[L2 Head Grouping] Processing {nheads_q} heads in groups of {group_size}") + + # Prepare output tensor + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + else: + out = torch.zeros_like(q) if out is None else out.zero_() + + # Collect outputs for each group + softmax_lse_list = [] + rng_state = None + + n_groups = (nheads_q + group_size - 1) // group_size + + for g in range(n_groups): + start_h = g * group_size + end_h = min((g + 1) * group_size, nheads_q) + + # Slice heads: bshd layout -> select heads on dim 2 + q_group = q[:, :, start_h:end_h, :].contiguous() + k_group = k[:, :, start_h:end_h, :].contiguous() + v_group = v[:, :, start_h:end_h, :].contiguous() + out_group = out[:, :, start_h:end_h, :].contiguous() + + # Handle alibi slopes if present + alibi_group = None + if alibi_slopes is not None: + alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h] + + # Handle descale tensors for fp8 + descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None + descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None + descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None + descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None + + # Call the original implementation for this group + out_g, softmax_lse_g, sd_mask_g, rng_state = _fwd_single_group( + q_group, k_group, v_group, out_group, alibi_group, + dropout_p, softmax_scale, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g + ) + + # Copy output back to the main tensor + out[:, :, start_h:end_h, :] = out_g + softmax_lse_list.append(softmax_lse_g) + + # Concatenate softmax_lse across heads + softmax_lse = torch.cat(softmax_lse_list, dim=1) # Assuming lse is [batch, heads, ...] + + return out, softmax_lse, None, rng_state + + + BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( dout: torch.Tensor, @@ -212,8 +323,7 @@ def bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() - if dropout_p > 0.0: - assert rng_state is not None + if rng_state is not None: philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None @@ -352,7 +462,7 @@ def bwd( print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) return dq, dk, dv, delta -def varlen_fwd( +def _varlen_fwd_single_group( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -379,6 +489,7 @@ def varlen_fwd( descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original varlen_fwd implementation for a single head group.""" if DEBUG: print() @@ -423,11 +534,9 @@ def varlen_fwd( if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - if dropout_p > 0.0: - metadata.need_dropout(dropout_p) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - else: - rng_state = None + # store rng state + metadata.need_dropout(dropout_p, return_softmax) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast # Check arguments metadata.check_args(q, k, v, out) @@ -495,6 +604,121 @@ def varlen_fwd( return out, softmax_lse, sd_mask, rng_state + +def varlen_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool , + causal: bool , + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Variable-length flash attention forward with L2 cache-aware head grouping. + + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. + + Layout: thd (total_seqlen, heads, head_dim) + """ + # Get shapes for head grouping decision + # Layout is thd: [total_seqlen, heads, head_dim] + total_seqlen, nheads_q, head_dim = q.shape + nheads_k = k.shape[1] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0 + ) + + if L2_HEAD_GROUPING_DEBUG: + print_head_grouping_info(nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0) + + if not should_group or group_size >= nheads_q: + # No grouping needed - use original implementation + return _varlen_fwd_single_group( + q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k, + block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, zero_tensors, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q, descale_k, descale_v, descale_o + ) + + # Process heads in groups for L2 cache efficiency + if L2_HEAD_GROUPING_DEBUG: + print(f"[L2 Head Grouping varlen] Processing {nheads_q} heads in groups of {group_size}") + + # Prepare output tensor + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + else: + out = torch.zeros_like(q) if out is None else out.zero_() + + # Collect outputs for each group + softmax_lse_list = [] + rng_state = None + + n_groups = (nheads_q + group_size - 1) // group_size + + for g in range(n_groups): + start_h = g * group_size + end_h = min((g + 1) * group_size, nheads_q) + + # Slice heads: thd layout -> select heads on dim 1 + q_group = q[:, start_h:end_h, :].contiguous() + k_group = k[:, start_h:end_h, :].contiguous() + v_group = v[:, start_h:end_h, :].contiguous() + out_group = out[:, start_h:end_h, :].contiguous() + + # Handle alibi slopes if present + alibi_group = None + if alibi_slopes is not None: + alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h] + + # Handle descale tensors for fp8 + descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None + descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None + descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None + descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None + + # Call the original implementation for this group + out_g, softmax_lse_g, sd_mask_g, rng_state = _varlen_fwd_single_group( + q_group, k_group, v_group, out_group, cu_seqlens_q, cu_seqlens_k, + seqused_k, leftpad_k, block_table_, alibi_group, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, + zero_tensors, causal, window_size_left, window_size_right, + softcap, return_softmax, gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g + ) + + # Copy output back to the main tensor + out[:, start_h:end_h, :] = out_g + softmax_lse_list.append(softmax_lse_g) + + # Concatenate softmax_lse across heads + softmax_lse = torch.cat(softmax_lse_list, dim=0) # varlen lse is [heads, total_seqlen] + + return out, softmax_lse, None, rng_state + def varlen_bwd( dout: torch.Tensor, q: torch.Tensor, @@ -563,8 +787,7 @@ def varlen_bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() - if dropout_p > 0.0: - assert rng_state is not None + if rng_state is not None: philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py new file mode 100644 index 00000000000..981f8ce5702 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py @@ -0,0 +1,271 @@ +""" +Infinity Cache (LLC) Aware Head Grouping for Flash Attention + +This module provides functionality to optimize flash attention by processing +heads in groups that fit in the Last Level Cache (LLC / Infinity Cache). + +AMD RDNA3 cache hierarchy: +- L2 Cache: 6 MB (per-die, fast) +- Infinity Cache (L3/LLC): 96 MB (acts as memory-side cache) + +For large sequence lengths, we want K,V to fit in the 96 MB Infinity Cache. +By processing heads in groups that fit, we achieve up to 2x speedup. + +Example: gfx1100 with 96MB Infinity Cache, 40 heads, seqlen=17160, head_dim=128 +- K,V for all 40 heads = 352 MB (exceeds 96 MB LLC) +- K,V for 10 heads = 88 MB (fits in 96 MB LLC) +- Processing 10 heads at a time gives 1.95x speedup +""" + +import os +import functools +from typing import Optional, Tuple, Dict +import torch + +# Infinity Cache (LLC) sizes for AMD GPUs in bytes +# Note: This is the L3/Infinity Cache, NOT the L2 cache +# RDNA3: L2=6MB, Infinity Cache (LLC)=96MB +AMD_LLC_CACHE_SIZES: Dict[str, int] = { + # RDNA3 consumer + "gfx1100": 96 * 1024 * 1024, # RX 7900 XTX/XT - 96 MB Infinity Cache + "gfx1101": 64 * 1024 * 1024, # RX 7800 XT - 64 MB Infinity Cache + "gfx1102": 32 * 1024 * 1024, # RX 7600 - 32 MB Infinity Cache +} + +# Legacy alias for backwards compatibility +AMD_L2_CACHE_SIZES = AMD_LLC_CACHE_SIZES + +# Environment variable to override LLC cache size (in MB) +LLC_CACHE_OVERRIDE_ENV = "FLASH_ATTN_LLC_CACHE_MB" +L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB" # Legacy alias + +# Environment variable to disable head grouping +DISABLE_HEAD_GROUPING_ENV = "FLASH_ATTN_DISABLE_HEAD_GROUPING" + +# Cached LLC size per device +_llc_cache_size_cache: Dict[int, int] = {} + + +@functools.lru_cache(maxsize=None) +def get_gcn_arch_name(device_index: int = 0) -> str: + """Get the GCN architecture name for an AMD GPU.""" + try: + props = torch.cuda.get_device_properties(device_index) + if hasattr(props, 'gcnArchName'): + return props.gcnArchName + # Fallback: try to get from name + name = props.name.lower() + if 'gfx' in name: + # Extract gfxXXXX from name + import re + match = re.search(r'gfx\d+', name) + if match: + return match.group() + except Exception: + pass + return "unknown" + + +def get_num_cus(device_index: int = 0) -> int: + """ + Get the number of Compute Units for an AMD GPU. + + Note: PyTorch's multi_processor_count may be incorrect for some AMD GPUs. + We use known values for common architectures. + """ + arch = get_gcn_arch_name(device_index) + + # Known CU counts for common GPUs + known_cus = { + "gfx1100": 96, # RX 7900 XTX + "gfx1101": 60, # RX 7800 XT + "gfx1102": 32, # RX 7600 + } + + if arch in known_cus: + return known_cus[arch] + + # Fallback to PyTorch (may be incorrect) + try: + props = torch.cuda.get_device_properties(device_index) + return props.multi_processor_count + except Exception: + return 96 # Default + + +def get_llc_cache_size(device_index: int = 0) -> int: + """ + Get Infinity Cache (LLC) size for the specified GPU device. + + For RDNA3, this is the 96 MB Infinity Cache, not the 6 MB L2. + + Returns: + LLC cache size in bytes + """ + global _llc_cache_size_cache + + if device_index in _llc_cache_size_cache: + return _llc_cache_size_cache[device_index] + + # Check for environment override (new name first, then legacy) + for env_var in [LLC_CACHE_OVERRIDE_ENV, L2_CACHE_OVERRIDE_ENV]: + if env_var in os.environ: + try: + size_mb = int(os.environ[env_var]) + size_bytes = size_mb * 1024 * 1024 + _llc_cache_size_cache[device_index] = size_bytes + return size_bytes + except ValueError: + pass + + # Get architecture and look up cache size + arch = get_gcn_arch_name(device_index) + + # Check exact match first + if arch in AMD_LLC_CACHE_SIZES: + size = AMD_LLC_CACHE_SIZES[arch] + _llc_cache_size_cache[device_index] = size + return size + + # Check prefix match (e.g., gfx1100 matches gfx1100) + for known_arch, size in AMD_LLC_CACHE_SIZES.items(): + if arch.startswith(known_arch): + _llc_cache_size_cache[device_index] = size + return size + + # Default: assume 96 MB (conservative for RDNA3) + default_size = 96 * 1024 * 1024 + _llc_cache_size_cache[device_index] = default_size + return default_size + + +# Legacy alias +get_l2_cache_size = get_llc_cache_size + + +def calculate_optimal_head_group_size( + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0, + llc_utilization: float = 1.5 # Use 150% of LLC - optimal for long sequences +) -> int: + """ + Calculate the optimal number of heads to process together to fit K,V in LLC. + """ + llc_size = get_llc_cache_size(device_index) + + # Get element size in bytes + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 # Default to fp16 + + # Memory for K and V per head + kv_per_head = seqlen_k * head_dim * elem_size * 2 # *2 for K and V + + # Target LLC usage + target_llc = int(llc_size * llc_utilization) + + # Calculate number of heads that fit + if kv_per_head == 0: + return 1 + + head_group_size = max(1, target_llc // kv_per_head) + + return head_group_size + + +def is_head_grouping_beneficial( + nheads: int, + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0, + threshold_ratio: float = 1.5 +) -> Tuple[bool, int]: + """ + Determine if head grouping would be beneficial and return optimal group size. + """ + # Check if disabled via environment + if os.environ.get(DISABLE_HEAD_GROUPING_ENV, "0") == "1": + return False, nheads + + llc_size = get_llc_cache_size(device_index) + + # Get element size + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + # Total K,V memory for all heads + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + + # Only group if K,V significantly exceeds LLC + if total_kv < llc_size * threshold_ratio: + return False, nheads + + # Calculate optimal group size + group_size = calculate_optimal_head_group_size( + seqlen_k, head_dim, dtype, device_index + ) + + # Only group if we'd have at least 2 groups + if group_size >= nheads: + return False, nheads + + # Minimum group size to avoid excessive kernel launches + min_group_size = max(1, nheads // 16) # At most 16 groups + group_size = max(group_size, min_group_size) + + return True, min(group_size, nheads) + + +def print_head_grouping_info( + nheads: int, + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0 +): + """Print diagnostic information about head grouping.""" + llc_size = get_llc_cache_size(device_index) + arch = get_gcn_arch_name(device_index) + num_cus = get_num_cus(device_index) + + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + should_group, group_size = is_head_grouping_beneficial( + nheads, seqlen_k, head_dim, dtype, device_index + ) + + print(f"\n=== Infinity Cache (LLC) Aware Head Grouping ===") + print(f"GPU: {arch} ({num_cus} CUs)") + print(f"Infinity Cache (LLC): {llc_size / (1024*1024):.1f} MB") + print(f"Heads: {nheads}, SeqLen: {seqlen_k}, HeadDim: {head_dim}") + print(f"Total K,V Memory: {total_kv / (1024*1024):.1f} MB") + print(f"LLC Ratio: {total_kv / llc_size:.2f}x") + print(f"Should Group: {should_group}") + if should_group: + kv_per_group = group_size * seqlen_k * head_dim * elem_size * 2 + num_groups = (nheads + group_size - 1) // group_size + print(f"Group Size: {group_size} heads ({num_groups} groups)") + print(f"K,V per Group: {kv_per_group / (1024*1024):.1f} MB") + print("=" * 48 + "\n") diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 0300e3902a1..167b99e0d81 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -48,7 +48,7 @@ class MetaData(): philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - use_exp2: bool = False + use_exp2: bool = True rotary_sin: Optional[torch.Tensor] = None rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False @@ -112,11 +112,10 @@ def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): self.rotary_interleaved = rotary_interleaved self.rotary_conjunction = rotary_conjunction - def need_dropout(self, dropout_p, return_scores = True): - if dropout_p > 0.0: - self.dropout_p = dropout_p - self.return_scores = return_scores - self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 + def need_dropout(self, dropout_p, return_softmax = True): + self.dropout_p = dropout_p + self.return_softmax = return_softmax + self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() diff --git a/flash_attn/fused_softmax.py b/flash_attn/fused_softmax.py deleted file mode 100644 index 382f94f092c..00000000000 --- a/flash_attn/fused_softmax.py +++ /dev/null @@ -1,201 +0,0 @@ -# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py -# for benchmarking. -# We added support for seqlen=2k and seqlen=4k - -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.enums import AttnMaskType -from fused_softmax_lib import ( - scaled_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_get_batch_per_block, - scaled_upper_triang_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, -) - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None - - -def scaled_upper_triang_masked_softmax(inputs, _, scale): - b, np, sq, sk = inputs.size() - assert sq == sk, "causal mask is only for self attention" - # Reshaping input to 3D tensor (attn_batches, sq, sk) - inputs = inputs.view(-1, sq, sk) - args = _cast_if_autocast_enabled(inputs, scale) - with torch.cuda.amp.autocast(enabled=False): - probs = ScaledUpperTriangMaskedSoftmax.apply(*args) - return probs.view(b, np, sq, sk) - - -# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. -# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. -# So I needed to manually write two `torch.autograd.Function` inheritances. -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. -class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -def scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.cuda.amp.autocast(enabled=False): - return ScaledMaskedSoftmax.apply(*args) - - -class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - if self.input_in_fp16 and self.input_in_bf16: - raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.") - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise RuntimeError("softmax should be in fp32 when scaled") - - if self.scaled_masked_softmax_fusion: - if self.attn_mask_type == AttnMaskType.causal: - self.fused_softmax_func = scaled_upper_triang_masked_softmax - elif self.attn_mask_type == AttnMaskType.padding: - self.fused_softmax_func = scaled_masked_softmax - else: - raise ValueError("Invalid attn_mask_type.") - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and ( - self.attn_mask_type == AttnMaskType.causal - or (self.attn_mask_type == AttnMaskType.padding and mask is not None) - ) - and 16 < sk <= 8192 # sk must be 16 ~ 8192 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 8192: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - # input.shape = [b, np, sq, sk] - scale = self.scale if self.scale is not None else 1.0 - return self.fused_softmax_func(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np) diff --git a/flash_attn/pyproject.toml b/flash_attn/pyproject.toml index 3201555763e..ce5eac916cd 100644 --- a/flash_attn/pyproject.toml +++ b/flash_attn/pyproject.toml @@ -1,3 +1,6 @@ [tool.black] line-length = 100 -target-version = ['py38'] \ No newline at end of file +target-version = 'py39' +[tool.ruff] +line-length = 100 +target-version = 'py39' \ No newline at end of file diff --git a/flash_attn/utils/testing.py b/flash_attn/utils/testing.py index 339af1767c4..81be51f1de8 100644 --- a/flash_attn/utils/testing.py +++ b/flash_attn/utils/testing.py @@ -1,10 +1,11 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. import math +from typing import Optional import torch from einops import rearrange, repeat -from padding import pad_input, unpad_input +from flash_attn.bert_padding import pad_input, unpad_input def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): @@ -158,7 +159,7 @@ def generate_qkv( def construct_local_mask( seqlen_q, seqlen_k, - window_size=(-1, -1), # -1 means infinite window size + window_size=(None, None), sink_token_length=0, query_padding_mask=None, key_padding_mask=None, @@ -181,7 +182,7 @@ def construct_local_mask( if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) - if window_size[0] < 0: + if window_size[0] is None: return col_idx > row_idx + sk - sq + window_size[1] else: sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk @@ -237,9 +238,10 @@ def attention_ref( causal=False, qv=None, q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), # -1 means infinite window size + window_size=(None, None), attention_chunk=0, sink_token_length=0, + learnable_sink: Optional[torch.Tensor] = None, softcap=0.0, upcast=True, reorder_ops=False, @@ -297,7 +299,7 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) local_mask = None - if window_size[0] >= 0 or window_size[1] >= 0: + if window_size[0] is not None or window_size[1] is not None: local_mask = construct_local_mask( seqlen_q, seqlen_k, @@ -323,7 +325,16 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias - attention = torch.softmax(scores, dim=-1).to(v.dtype) + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max) + attention = (unnormalized_scores / normalizer).to(v.dtype) # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 33e5d282716..e94d325d42d 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -68,7 +68,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w else: row_idx = torch.arange(seqlen_q, device='cuda') col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) - col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) avg_seqlen = (col_right - col_left + 1).float().mean().item() return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index 6d9b5f4f596..fdae7616683 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -109,6 +109,7 @@ struct CollectiveEpilogueBwd { Element* ptr_dV; ShapedKV const shape_dV; StridedKV const stride_dV; + int const num_batch; int const num_heads_q; int* dk_semaphore; int* dv_semaphore; @@ -369,7 +370,8 @@ struct CollectiveEpilogueBwdGQA { ElementAccum* ptr_dVaccum; ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; - int num_heads_q; + int const num_batch; + int const num_heads_q; int* dk_semaphore; int* dv_semaphore; int const* cu_seqlens; @@ -387,6 +389,7 @@ struct CollectiveEpilogueBwdGQA { cutlass::FastDivmod qhead_per_khead_divmod; int* dk_semaphore; int* dv_semaphore; + int const num_batch; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; @@ -400,7 +403,7 @@ struct CollectiveEpilogueBwdGQA { return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), args.dk_semaphore, args.dv_semaphore, - args.cu_seqlens, args.seqused}; + args.num_batch, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -449,8 +452,8 @@ struct CollectiveEpilogueBwdGQA { cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); } - // int const num_batch = params.num_batch; - int const num_batch = get<2>(params.shape_dKaccum); + int const num_batch = params.num_batch; + // int const num_batch = get<2>(params.shape_dKaccum); // erroneously returns 1 for varlen int const num_head_kv = get<1>(params.shape_dKaccum); int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv; using Barrier = cutlass::GenericBarrier; diff --git a/hopper/flash.h b/hopper/flash.h index bee89e5f054..6848e8c9dbd 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -152,10 +152,16 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; - // int * __restrict__ num_m_blocks_ptr; + int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; + int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual + int * __restrict__ num_nheads_in_l2_ptr; bool skip_scheduler_metadata_computation; + bool varlen_sort_batches; + int tile_count_semaphore_offset; + bool head_swizzle; + bool prepare_varlen_pdl; int arch; int num_sm; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 33185bf2304..7ab4352984e 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -39,6 +39,14 @@ PyObject* PyInit__C(void) #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + +namespace { +inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) { + return at::cuda::CUDAGuard(static_cast(t.get_device())); +} +} // namespace + void set_params_fprop(Flash_fwd_params ¶ms, // sizes const size_t b, @@ -250,6 +258,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -257,6 +266,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -268,11 +278,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -283,6 +295,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -290,6 +303,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -301,11 +315,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -329,11 +345,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } } + #endif return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif @@ -525,8 +543,7 @@ mha_fwd_get_scheduler_metadata( bool has_softcap, int64_t num_splits, std::optional pack_gqa_, - int64_t sm_margin - ) { + int64_t sm_margin) { TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); @@ -585,8 +602,9 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_dynamic_split = params.b <= 992; - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -597,24 +615,41 @@ mha_fwd_get_scheduler_metadata( // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; + auto device_guard = make_cuda_guard_from_tensor(seqused_k); auto opts = seqused_k.options(); // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; - if (scheduler_needs_semaphore || use_dynamic_split) { - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); + tile_count_semaphore = torch::empty( + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, + opts.dtype(torch::kInt32)); + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; if (scheduler_needs_semaphore) { - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + if (!use_prepare_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; } else { params.tile_count_semaphore = nullptr; } - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; } - if (params.num_splits_dynamic_ptr) { + if (use_prepare_varlen) { auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); @@ -847,7 +882,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + auto device_guard = make_cuda_guard_from_tensor(q); at::Tensor softmax_lse; if (!is_varlen_q) { @@ -938,11 +973,11 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } } - - // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -955,8 +990,17 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (scheduler_needs_semaphore || use_dynamic_split) { - int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); if (scheduler_metadata_.has_value()) { at::Tensor scheduler_metadata = scheduler_metadata_.value(); @@ -968,15 +1012,22 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } else { tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); } - if (scheduler_needs_semaphore && !use_dynamic_split) { + if (scheduler_needs_semaphore && !use_prepare_varlen) { tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing } - params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); @@ -1134,7 +1185,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { // need to zero out the semaphore in this case - tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); + tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. @@ -1213,7 +1264,7 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::tuple mha_bwd( +std::tuple mha_bwd( at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k @@ -1316,6 +1367,7 @@ std::tuplemajor * 10 + at::cuda::getCurrentDeviceProperties()->minor; int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); int const head_size_v_rounded = head_size_rounded; + TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) @@ -1417,7 +1469,7 @@ std::tuple(); if (num_heads_k != num_heads && params.deterministic) { - // TODO: do we need to zero them out? - at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + at::Tensor dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + at::Tensor dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); } @@ -1511,7 +1563,7 @@ std::tuple @@ -1675,7 +1727,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "int window_size_right = -1," "float softcap = 0.0," "bool deterministic = False," - "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("fwd_combine(" "Tensor out_partial," "Tensor lse_partial," diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp new file mode 100644 index 00000000000..5ae58bdd129 --- /dev/null +++ b/hopper/flash_api_stable.cpp @@ -0,0 +1,1987 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#include + +#include + +#include "flash.h" +#include "static_switch.h" +#include "tile_size.h" +#include "heuristics.h" +#include "cuda_check.h" + +#include +#include +#include +#include +#include + +// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h +extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); + +#include +#include + +#include +#include +#include +#include + +using torch::stable::Tensor; +namespace tsa = torch::stable::accelerator; + +namespace { +inline tsa::DeviceGuard make_device_guard(const Tensor& t) { + return tsa::DeviceGuard(static_cast(t.get_device())); +} +std::deque device_flags; +std::vector device_properties; + +void initVectors() { + static bool init_flag [[maybe_unused]] = []() { + int device_count; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_flags.resize(device_count); + device_properties.resize(device_count); + return true; + }(); +} + +void initDeviceProperty(int device_index) { + cudaDeviceProp device_prop{}; + cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_properties[device_index] = device_prop; +} + +// Helper function to get device properties using raw CUDA APIs +cudaDeviceProp* get_device_prop() { + initVectors(); + int device_index; + cudaError_t err = cudaGetDevice(&device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDevice failed: " + + std::string(cudaGetErrorString(err))); + } + + std::call_once(device_flags[device_index], initDeviceProperty, device_index); + return &device_properties[device_index]; +} +} // anonymous namespace + + +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the STABLE_TORCH_LIBRARY static initializers + below are run. */ +PyObject* PyInit__C(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); +} +} + +#define CHECK_DEVICE(x) STD_TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) \ + do { \ + auto expected_dims = std::vector{__VA_ARGS__}; \ + STD_TORCH_CHECK(x.dim() == static_cast(expected_dims.size()), #x " must have " + std::to_string(expected_dims.size()) + " dimensions, got " + std::to_string(x.dim())); \ + for (size_t i = 0; i < expected_dims.size(); ++i) { \ + STD_TORCH_CHECK(x.size(i) == expected_dims[i], #x " dimension " + std::to_string(i) + " must have size " + std::to_string(expected_dims[i]) + ", got " + std::to_string(x.size(i))); \ + } \ + } while (0) +#define CHECK_CONTIGUOUS(x) STD_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const Tensor q, + const Tensor k, + const Tensor v, + Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + const int sm_margin=0) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.scalar_type() == torch::headeronly::ScalarType::BFloat16; + params.is_e4m3 = q.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.o_batch_stride = out.stride(0); + } + if (cu_seqlens_k_d == nullptr) { + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); + params.seqused_k = static_cast(seqused_k); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + STD_TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + STD_TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { window_size_left = seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + + auto dprops = get_device_prop(); + params.arch = dprops->major * 10 + dprops->minor; + params.num_sm = dprops->multiProcessorCount - sm_margin; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const Tensor q, + const Tensor k, + const Tensor v, + const Tensor out, + const Tensor dout, + Tensor dq, + Tensor dk, + Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + bool deterministic=false, + int const sm_margin=0) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + seqused_q, + seqused_k, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + + // Set the pointers and strides. + params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + +template +void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (!params.is_e4m3) { + if (params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } + } else { + #ifndef FLASHATTENTION_DISABLE_FP8 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + } + #endif + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP8."); + #endif + } +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_fwd_(params, stream); + // }); + STD_TORCH_CHECK(params.num_splits >= 1); + ARCH_SWITCH(params.arch, Arch, [&] { + SPLIT_SWITCH(params.num_splits > 1, Split, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { + PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; + SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { + run_mha_fwd_constexpr(params, stream); + }); + }); + }); + }); + }); +} + +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { + #ifndef FLASHATTENTION_DISABLE_SPLIT + // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + // so that kBlockM is smaller and we have more parallelism. + if (params.is_fp32) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else if (params.is_bf16) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } + #else + STD_TORCH_CHECK(false, "This flash attention build does not support combine kernels."); + #endif +} + +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; +} + +inline bool get_pack_gqa(Flash_fwd_params const& params) { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. + // Has little effect on speed. + if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } + #ifdef FLASHATTENTION_DISABLE_PACKGQA + return false; + #else + // params.page_table must already be set + if (params.h == params.h_k) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); + #endif +} + +inline int get_num_splits(Flash_fwd_params const& params) { + #ifdef FLASHATTENTION_DISABLE_SPLIT + return 1; + #else + // Always enable PackGQA for Split + // params.page_table must already be set + // This needs to match the kernel configs + bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits + // has not been set here. It's OK though because we might just underestimate kBlockN a bit + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); + // If is_local, we're not going to load all of seqlen_k + int const seqlen_k_loaded = !params.is_local + ? params.seqlen_k + : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); + int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; + int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); + // Always enable PackGQA for Split + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); + #endif +} + +inline int get_max_headdim() { + #ifndef FLASHATTENTION_DISABLE_HDIM256 + return 256; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + return 192; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + return 128; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + return 96; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM64 + return 64; + #endif + return 0; +} + +inline int round_up_headdim(int head_size) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (head_size <= 64) { return 64; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (head_size <= 96) { return 96; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (head_size <= 128) { return 128; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (head_size <= 192) { return 192; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (head_size <= 256) { return 256; } + #endif + return 256; +} + +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +Tensor +mha_fwd_get_scheduler_metadata( + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, + torch::headeronly::ScalarType qkv_dtype, + Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + bool has_softcap, + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin) { + + STD_TORCH_CHECK(qkv_dtype == torch::headeronly::ScalarType::Half || qkv_dtype == torch::headeronly::ScalarType::BFloat16 || qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // Reset the parameters + Flash_fwd_params params{}; + params.is_bf16 = qkv_dtype == torch::headeronly::ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn; + params.b = batch_size; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.h = num_heads; + params.h_k = num_heads_k; + params.d = headdim; + params.dv = headdim_v; + params.d_rounded = round_up_headdim(headdim); + params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); + params.seqlen_knew = max_seqlen_k_new; + + bool const is_varlen_q = cu_seqlens_q_.has_value(); + params.cu_seqlens_q = is_varlen_q ? static_cast(cu_seqlens_q_.value().data_ptr()) : nullptr; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + params.cu_seqlens_k = is_varlen_k ? static_cast(cu_seqlens_k_.value().data_ptr()) : nullptr; + params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? static_cast(cu_seqlens_k_new_.value().data_ptr()): nullptr; + params.seqused_q = seqused_q_.has_value() ? static_cast(seqused_q_.value().data_ptr()) : nullptr; + params.seqused_k = static_cast(seqused_k.data_ptr()); + params.leftpad_k = leftpad_k_.has_value() ? static_cast(leftpad_k_.value().data_ptr()) : nullptr; + params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; + if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + auto dprops = get_device_prop(); + params.arch = dprops->major * 10 + dprops->minor; + params.num_sm = dprops->multiProcessorCount - sm_margin; + params.softcap = has_softcap ? 1.0f : 0.0f; + + params.page_size = page_size.has_value() ? page_size.value() : 1; + params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); + + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + bool is_varlen = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(seqused_k); + + // This needs to be set after get_num_splits + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); + tile_count_semaphore = torch::stable::new_empty( + seqused_k, + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, + std::make_optional(torch::headeronly::ScalarType::Int)); + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + if (scheduler_needs_semaphore) { + if (!use_prepare_varlen) { torch::stable::zero_(tile_count_semaphore); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = static_cast(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset; + } else { + params.tile_count_semaphore = nullptr; + } + } + + if (use_prepare_varlen) { + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + return tile_count_semaphore; +} + +// b: batch_size +// b_k: batch_size_k +// s_q: seqlen_q +// s_k: seqlen_k +// s_k_new: seqlen_k_new +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple +mha_fwd(Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + // TODO: check if we need max_seqlen_k + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin + ) { + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16 || q_type == torch::headeronly::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + if (dprops->major < 9) { + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); + } + STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + Tensor page_table; + const bool paged_KV = page_table_.has_value(); + if (paged_KV) { + page_table = page_table_.value(); + CHECK_DEVICE(page_table); + STD_TORCH_CHECK(page_table.scalar_type() == torch::headeronly::ScalarType::Int, "page_table must have dtype torch.int32"); + STD_TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + } + + Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + STD_TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); + STD_TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); + } + + const int batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1; + int seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value(); + int total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0); + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); + int const num_pages = !paged_KV ? 0 : k.size(0); + int const page_size = !paged_KV ? 1 : k.size(1); + int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + if (!kv_batch_idx_.has_value()) { + STD_TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + int const max_headdim = get_max_headdim(); + STD_TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + STD_TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + (head_size <= 64 && head_size_v <= 512), + "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " + "or (Q/K <= 64 and V <= 512)."); + STD_TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((head_size <= 64 || head_size > 128) || !paged_KV) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!paged_KV) { + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + } else { + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()) { + auto seqused_k = seqused_k_.value(); + STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + STD_TORCH_CHECK(leftpad_k.scalar_type() == torch::headeronly::ScalarType::Int, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + int const alignment = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? 16 : 8; + STD_TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + STD_TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto out_type = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? torch::headeronly::ScalarType::BFloat16 : q_type; + Tensor out; + if (out_.has_value()) { + out = out_.value(); + STD_TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); + CHECK_DEVICE(out); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + } + } else { + out = !is_varlen_q + ? torch::stable::new_empty(q, {batch_size, seqlen_q, num_heads, head_size_v}, std::make_optional(out_type)) + : torch::stable::new_empty(q, {total_q, num_heads, head_size_v}, std::make_optional(out_type)); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(q); + + Tensor softmax_lse; + if (!is_varlen_q) { + softmax_lse = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + softmax_lse = torch::stable::new_empty(q, {num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } + if (paged_KV) { + params.page_table = static_cast(page_table.data_ptr()); + params.page_table_batch_stride = page_table.stride(0); + } + params.page_size = page_size; + params.num_pages = num_pages; + + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma + Tensor k_new, v_new; + STD_TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + STD_TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); + STD_TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + Tensor cu_seqlens_k_new; + bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); + if (is_varlen_k_new) { + cu_seqlens_k_new = cu_seqlens_k_new_.value(); + CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); + STD_TORCH_CHECK(cu_seqlens_k_new.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k_new must have dtype torch.int32"); + } + k_new = k_new_.value(); + v_new = v_new_.value(); + STD_TORCH_CHECK(k_new.scalar_type() == q_type, "k_new must have the same dtype as query"); + STD_TORCH_CHECK(v_new.scalar_type() == q_type, "v_new must have the same dtype as query"); + CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); + STD_TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new + int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; + int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); + if (!is_varlen_k_new) { + CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + } + params.seqlen_knew = seqlen_k_new; + params.total_knew = total_k_new; + params.knew_ptr = k_new.data_ptr(); + params.vnew_ptr = v_new.data_ptr(); + // All stride are in elements, not bytes. + params.knew_row_stride = k_new.stride(-3); + params.vnew_row_stride = v_new.stride(-3); + params.knew_head_stride = k_new.stride(-2); + params.vnew_head_stride = v_new.stride(-2); + if (!is_varlen_k_new) { + params.knew_batch_stride = k_new.stride(0); + params.vnew_batch_stride = v_new.stride(0); + } + if (is_varlen_k_new) { + params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); + } + } + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + // This needs to be set after get_num_splits + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + // We don't use the persistent scheduler if Split and not Varlen + bool const scheduler_needs_semaphore = params.arch >= 90 + ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); + if (scheduler_metadata_.has_value()) { + Tensor scheduler_metadata = scheduler_metadata_.value(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + STD_TORCH_CHECK(scheduler_metadata.scalar_type() == torch::headeronly::ScalarType::Int, "scheduler_metadata must have dtype int32"); + tile_count_semaphore = scheduler_metadata; + } else { + tile_count_semaphore = torch::stable::new_empty(q, {metadata_size}, torch::headeronly::ScalarType::Int); + } + if (scheduler_needs_semaphore && !use_prepare_varlen) { + torch::stable::zero_(tile_count_semaphore); // If varlen we'll manually do the zero-ing + } + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? static_cast(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later + } + + if (q_v_.has_value()) { + STD_TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + STD_TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + STD_TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + Tensor q_v = q_v_.value(); + STD_TORCH_CHECK(q_v.scalar_type() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + STD_TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + + if (rotary_cos_.has_value()) { + STD_TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + STD_TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + STD_TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + if (paged_KV) { + STD_TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + } + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + STD_TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + STD_TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + STD_TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); + STD_TORCH_CHECK(seqlens_rotary.scalar_type() == torch::headeronly::ScalarType::Int, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = static_cast(seqlens_rotary.data_ptr()); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); + STD_TORCH_CHECK(kv_batch_idx.scalar_type() == torch::headeronly::ScalarType::Int, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + Tensor out_accum, softmax_lse_accum; + auto outaccum_type = torch::headeronly::ScalarType::Float; + if (params.num_splits > 1) { + STD_TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); + if (!is_varlen_q) { + out_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, std::make_optional(outaccum_type)); + softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + params.oaccum_batch_stride = out_accum.stride(1); + params.lseaccum_batch_stride = softmax_lse_accum.stride(1); + } else { + out_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q, head_size_v}, std::make_optional(outaccum_type)); + softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + params.is_fp32 = false; + params.oaccum_ptr = out_accum.data_ptr(); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_split_stride = out_accum.stride(0); + params.oaccum_row_stride = out_accum.stride(-2); + params.oaccum_head_stride = out_accum.stride(-3); + params.lseaccum_split_stride = softmax_lse_accum.stride(0); + params.lseaccum_head_stride = softmax_lse_accum.stride(-2); + } + + if (q_type == torch::headeronly::ScalarType::Float8_e4m3fn) { + if (q_descale_.has_value()) { + auto q_descale = q_descale_.value(); + CHECK_DEVICE(q_descale); + CHECK_SHAPE(q_descale, batch_size, num_heads_k); + params.q_descale_ptr = static_cast(q_descale.data_ptr()); + params.q_descale_batch_stride = q_descale.stride(0); + params.q_descale_head_stride = q_descale.stride(1); + } else { + params.q_descale_ptr = nullptr; + } + if (k_descale_.has_value()) { + auto k_descale = k_descale_.value(); + CHECK_DEVICE(k_descale); + CHECK_SHAPE(k_descale, batch_size, num_heads_k); + params.k_descale_ptr = static_cast(k_descale.data_ptr()); + params.k_descale_batch_stride = k_descale.stride(0); + params.k_descale_head_stride = k_descale.stride(1); + } else { + params.k_descale_ptr = nullptr; + } + if (v_descale_.has_value()) { + auto v_descale = v_descale_.value(); + CHECK_DEVICE(v_descale); + CHECK_SHAPE(v_descale, batch_size, num_heads_k); + params.v_descale_ptr = static_cast(v_descale.data_ptr()); + params.v_descale_batch_stride = v_descale.stride(0); + params.v_descale_head_stride = v_descale.stride(1); + } else { + params.v_descale_ptr = nullptr; + } + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + #ifdef FLASHATTENTION_DISABLE_SPLIT + STD_TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); + #endif + #ifdef FLASHATTENTION_DISABLE_PACKGQA + STD_TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + #endif + #ifdef FLASHATTENTION_DISABLE_PAGEDKV + STD_TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); + #endif + #ifdef FLASHATTENTION_DISABLE_APPENDKV + STD_TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); + #endif + + if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + run_mha_fwd(params, stream); + if (params.num_splits > 1) { + if (out_type == torch::headeronly::ScalarType::BFloat16) { + // Since we want output in BF16. Otherwise fwd_combine will output to FP16 + params.is_bf16 = true; + } + // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 + // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. + // if (is_varlen_q && !seqused_q_.has_value()) { + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + // This will zero out the semaphore if needed + run_mha_fwd_combine(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + auto slice = torch::stable::narrow(tile_count_semaphore, 0, params.tile_count_semaphore_offset, 1); + torch::stable::zero_(slice); + } + } else if (total_q > 0 && num_heads_k > 0) { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + torch::stable::zero_(out); + torch::stable::fill_(softmax_lse, std::numeric_limits::infinity()); + } + + // return {out, softmax_lse}; + return {out, softmax_lse, out_accum, softmax_lse_accum}; +} + +#ifdef FLASHATTENTION_DISABLE_BACKWARD +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + STD_TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); +} +#else +template +void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) { + if (!params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + } +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + // FP16_SWITCH(!params.is_bf16, [&] { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_bwd_(params, stream); + // }); + // }); + ARCH_SWITCH(params.arch, Arch, [&] { + SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { + run_mha_bwd_constexpr(params, stream); + }); + }); +} +#endif + + +// b: batch_size +// s_q: seqlen_q +// s_k: seqlen_k +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple mha_bwd( + Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool deterministic, + int64_t sm_margin +) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + STD_TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "FlashAttention only support fp16 and bf16 data type"); + STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + STD_TORCH_CHECK(out.scalar_type() == q_type, "query and out must have the same dtype"); + STD_TORCH_CHECK(dout.scalar_type() == q_type, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + STD_TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + } + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + // auto const sizes = q.sizes(); + int const batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1; + int const seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value(); + int const total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0); + int const num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + STD_TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + STD_TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); + int const max_headdim = get_max_headdim(); + STD_TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + if (is_causal) { window_size_right = 0; } + // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true. + // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). + is_causal = window_size_left < 0 && window_size_right == 0; + + int const arch = dprops->major * 10 + dprops->minor; + int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); + int const head_size_v_rounded = head_size_rounded; + STD_TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); + // Very important that these match the kernel configs + bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) + : (head_size_rounded <= 96 ? 64 + : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80) + : 64)); + int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; + int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; + int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); + int const kBlockN_sm90 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 96 : 80); + int const kBlockN_sm80 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 80 : 64); + int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 + : (head_size_rounded <= 96 ? 128 + : (head_size_rounded <= 128 ? 96 + : (head_size_rounded <= 192 ? 64 : 64))); + int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); + int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); + int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); + int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + CHECK_SHAPE(dout, total_q, num_heads, head_size_v); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()){ + auto seqused_k = seqused_k_.value(); + STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + STD_TORCH_CHECK(dq.scalar_type() == q_type, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + STD_TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } + } else { + dq = torch::stable::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + STD_TORCH_CHECK(dk.scalar_type() == q_type, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + STD_TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } + } else { + dk = torch::stable::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + STD_TORCH_CHECK(dv.scalar_type() == q_type, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + STD_TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); + } + } else { + dv = torch::stable::empty_like(v); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(q); + + // auto opts = q.options(); + // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + Tensor softmax_d, softmax_lse_log2; + if (!is_varlen) { + // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + softmax_d = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse_log2 = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + softmax_d = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse_log2 = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + Tensor dq_accum, dk_accum, dv_accum; + if (!is_varlen) { + dq_accum = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + dq_accum = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + if (num_heads_k != num_heads) { // MQA / GQA + if (!is_varlen) { + dk_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dk_accum = torch::stable::fill_(dk_accum, 0.0); + dv_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dv_accum = torch::stable::fill_(dv_accum, 0.0); + } else { + dk_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dk_accum = torch::stable::fill_(dk_accum, 0.0); + dv_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dv_accum = torch::stable::fill_(dv_accum, 0.0); + } + } + + Flash_bwd_params params; + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout, dq, dk, dv, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + dq_accum.data_ptr(), + num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, + num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + 0, // attention_chunk + softcap, + deterministic, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + + // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::headeronly::ScalarType::Int)) : torch::empty({1}, opts.dtype(torch::headeronly::ScalarType::Int)); + // params.tile_count_semaphore = static_cast(tile_count_semaphore.data_ptr()); + // Will be zero'ed out in the backward preprocess kernel + Tensor dq_semaphore = torch::stable::new_empty(q, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, std::make_optional(torch::headeronly::ScalarType::Int)); + params.dq_semaphore = static_cast(dq_semaphore.data_ptr()); + if (num_heads_k != num_heads && params.deterministic) { + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + Tensor dk_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + Tensor dv_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + params.dk_semaphore = static_cast(dk_semaphore.data_ptr()); + params.dv_semaphore = static_cast(dv_semaphore.data_ptr()); + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + + if (total_q > 0 && total_k > 0 && num_heads_k > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + run_mha_bwd(params, stream); + } else if (total_k > 0 && num_heads_k > 0) { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + torch::stable::zero_(dk); + torch::stable::zero_(dv); + torch::stable::zero_(softmax_d); + } else if (total_q > 0 && num_heads_k > 0) { + torch::stable::zero_(dq); + torch::stable::zero_(softmax_d); + } + + return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; +} + +std::tuple +mha_combine(Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads + std::optional out_, // batch_size x seqlen x num_heads x head_size + std::optional out_dtype_ + ) { + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); + + auto out_partial_type = out_partial.scalar_type(); + STD_TORCH_CHECK(out_partial_type == torch::headeronly::ScalarType::Float, "Attention combine function only support fp32 data type"); + STD_TORCH_CHECK(lse_partial.scalar_type() == torch::headeronly::ScalarType::Float, "Attention combine function only support fp32 data type"); + + CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial); + + STD_TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension"); + + // const auto sizes = out_partial.sizes(); + + const int num_splits = out_partial.size(0); + const int batch_size = out_partial.size(1); + const int seqlen = out_partial.size(2); + const int num_heads = out_partial.size(3); + const int head_size_og = out_partial.size(4); + STD_TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); + + CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); + CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads); + + int const alignment = 4; + Tensor out_partial_padded; + auto pad = [](Tensor x, int alignment) { + return x.size(-1) % alignment == 0 ? x : torch::stable::pad(x, {0, alignment - x.size(-1) % alignment}); + }; + out_partial_padded = pad(out_partial, alignment); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, alignment); + + // auto opts = out_partial.options(); + torch::headeronly::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); + STD_TORCH_CHECK(out_type == torch::headeronly::ScalarType::Float || out_type == torch::headeronly::ScalarType::BFloat16 || out_type == torch::headeronly::ScalarType::Half, "Output type must be FP32, FP16 or BF16"); + Tensor out; + if (out_.has_value()) { + out = out_.value(); + STD_TORCH_CHECK(out.scalar_type() == out_type); + CHECK_DEVICE(out); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og); + if (head_size_og % alignment != 0) { + out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); + } + } else { + out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(out_partial); + + auto softmax_lse = torch::stable::new_empty(out_partial, {batch_size, num_heads, seqlen}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse = torch::stable::transpose(softmax_lse, 1, 2); + + Flash_fwd_params params {}; // Need to reset the params to set everything to zero + params.is_fp32 = out_type == torch::headeronly::ScalarType::Float; + params.is_bf16 = out_type == torch::headeronly::ScalarType::BFloat16; + params.oaccum_ptr = out_partial_padded.data_ptr(); + params.softmax_lseaccum_ptr = lse_partial.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = softmax_lse.data_ptr(); + params.b = batch_size; + params.h = num_heads; + params.seqlen_q = seqlen; + params.dv = head_size; + params.num_splits = num_splits; + params.oaccum_split_stride = out_partial_padded.stride(0); + params.oaccum_row_stride = out_partial_padded.stride(2); + params.oaccum_head_stride = out_partial_padded.stride(3); + params.oaccum_batch_stride = out_partial_padded.stride(1); + params.lseaccum_split_stride = lse_partial.stride(0); + params.lseaccum_head_stride = lse_partial.stride(3); + params.lseaccum_batch_stride = lse_partial.stride(1); + params.o_row_stride = out.stride(1); + params.o_head_stride = out.stride(2); + params.o_batch_stride = out.stride(0); + params.arch = dprops->major * 10 + dprops->minor; + + if (seqlen > 0 && batch_size > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + run_mha_fwd_combine(params, stream, false /*enable_pdl*/); + } + + Tensor out_padded = out; + if (head_size_og % alignment != 0) { + out = torch::stable::narrow(out, -1, 0, head_size_og); + // if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, softmax_lse}; +} + +void boxed_mha_fwd( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto q = to(stack[0]); + auto k = to(stack[1]); + auto v = to(stack[2]); + auto k_new = to>(stack[3]); + auto v_new = to>(stack[4]); + auto q_v = to>(stack[5]); + auto out = to>(stack[6]); + auto cu_seqlens_q = to>(stack[7]); + auto cu_seqlens_k = to>(stack[8]); + auto cu_seqlens_k_new = to>(stack[9]); + auto seqused_q = to>(stack[10]); + auto seqused_k = to>(stack[11]); + auto max_seqlen_q = to>(stack[12]); + auto max_seqlen_k = to>(stack[13]); + auto page_table = to>(stack[14]); + auto kv_batch_idx = to>(stack[15]); + auto leftpad_k = to>(stack[16]); + auto rotary_cos = to>(stack[17]); + auto rotary_sin = to>(stack[18]); + auto seqlens_rotary = to>(stack[19]); + auto q_descale = to>(stack[20]); + auto k_descale = to>(stack[21]); + auto v_descale = to>(stack[22]); + auto softmax_scale = to>(stack[23]); + auto is_causal = to(stack[24]); + auto window_size_left = to(stack[25]); + auto window_size_right = to(stack[26]); + auto attention_chunk = to(stack[27]); + auto softcap = to(stack[28]); + auto is_rotary_interleaved = to(stack[29]); + auto scheduler_metadata = to>(stack[30]); + auto num_splits = to(stack[31]); + auto pack_gqa = to>(stack[32]); + auto sm_margin = to(stack[33]); + + auto [out_, softmax_lse, out_accum, softmax_lse_accum] = mha_fwd(q, k, v, k_new, v_new, q_v, out, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale, softmax_scale, is_causal, window_size_left, window_size_right, attention_chunk, softcap, is_rotary_interleaved, scheduler_metadata, num_splits, pack_gqa, sm_margin); + + + stack[0] = from(out_); + stack[1] = from(softmax_lse); + stack[2] = from(out_accum); + stack[3] = from(softmax_lse_accum); +} + +void boxed_mha_bwd( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto dout = to(stack[0]); + auto q = to(stack[1]); + auto k = to(stack[2]); + auto v = to(stack[3]); + auto out = to(stack[4]); + auto softmax_lse = to(stack[5]); + auto dq = to>(stack[6]); + auto dk = to>(stack[7]); + auto dv = to>(stack[8]); + auto cu_seqlens_q = to>(stack[9]); + auto cu_seqlens_k = to>(stack[10]); + auto seqused_q = to>(stack[11]); + auto seqused_k = to>(stack[12]); + auto max_seqlen_q = to>(stack[13]); + auto max_seqlen_k = to>(stack[14]); + auto softmax_scale = to>(stack[15]); + auto is_causal = to(stack[16]); + auto window_size_left = to(stack[17]); + auto window_size_right = to(stack[18]); + auto softcap = to(stack[19]); + auto deterministic = to(stack[20]); + auto sm_margin = to(stack[21]); + + auto [softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); + + stack[0] = from(softmax_d); + stack[1] = from(softmax_lse_log2); + stack[2] = from(dq_accum); + stack[3] = from(dk_accum); + stack[4] = from(dv_accum); +} + +void boxed_mha_combine( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto out_partial = to(stack[0]); + auto lse_partial = to(stack[1]); + auto out = to>(stack[2]); + auto out_dtype = to>(stack[3]); + + auto [out_, softmax_lse] = mha_combine(out_partial, lse_partial, out, out_dtype); + + stack[0] = from(out_); + stack[1] = from(softmax_lse); +} + +void boxed_mha_fwd_get_scheduler_metadata( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto batch_size = to(stack[0]); + auto max_seqlen_q = to(stack[1]); + auto max_seqlen_k = to(stack[2]); + auto num_heads = to(stack[3]); + auto num_heads_k = to(stack[4]); + auto headdim = to(stack[5]); + auto headdim_v = to(stack[6]); + auto qkv_dtype = to(stack[7]); + auto seqused_k = to(stack[8]); + auto cu_seqlens_q = to>(stack[9]); + auto cu_seqlens_k = to>(stack[10]); + auto cu_seqlens_k_new = to>(stack[11]); + auto seqused_q = to>(stack[12]); + auto leftpad_k = to>(stack[13]); + auto page_size = to>(stack[14]); + auto max_seqlen_k_new = to(stack[15]); + auto is_causal = to(stack[16]); + auto window_size_left = to(stack[17]); + auto window_size_right = to(stack[18]); + auto attention_chunk = to(stack[19]); + auto has_softcap = to(stack[20]); + auto num_splits = to(stack[21]); + auto pack_gqa = to>(stack[22]); + auto sm_margin = to(stack[23]); + + auto scheduler_metadata = mha_fwd_get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k, headdim, headdim_v, qkv_dtype, seqused_k, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, leftpad_k, page_size, max_seqlen_k_new, is_causal, window_size_left, window_size_right, attention_chunk, has_softcap, num_splits, pack_gqa, sm_margin); + + stack[0] = from(scheduler_metadata); +} + +STABLE_TORCH_LIBRARY(flash_attn_3, m) { + m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "float softcap = 0.0," + "bool is_rotary_interleaved = False," + "Tensor? scheduler_metadata = None," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor(dq!)? dq = None," + "Tensor(dk!)? dk = None," + "Tensor(dv!)? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("fwd_combine(" + "Tensor out_partial," + "Tensor lse_partial," + "Tensor(out!)? out = None," + "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); + m.def("get_scheduler_metadata(" + "int batch_size," + "int max_seqlen_q," + "int max_seqlen_k," + "int num_heads," + "int num_heads_k," + "int headdim," + "int headdim_v," + "ScalarType qkv_dtype," + "Tensor seqused_k," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? leftpad_k = None," + "int? page_size = None," + "int max_seqlen_k_new = 0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "bool has_softcap = False," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { + m.impl("fwd", &boxed_mha_fwd); + m.impl("bwd", &boxed_mha_bwd); + m.impl("fwd_combine", &boxed_mha_combine); + m.impl("get_scheduler_metadata", &boxed_mha_fwd_get_scheduler_metadata); +} diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index cfb8881b4b2..44d1f027cb0 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -1,6 +1,6 @@ # Copyright (c) 2023, Tri Dao. -from typing import Optional, Union +from typing import Optional, Union, List, Tuple import torch import torch.nn as nn @@ -17,40 +17,68 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x +def round_multiple(x, m): + return (x + m - 1) // m * m + + +def round_up_headdim(head_size: int) -> int: + from flash_attn_config import CONFIG + + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]: + if head_size <= 64: + return 64 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]: + if head_size <= 96: + return 96 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]: + if head_size <= 128: + return 128 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]: + if head_size <= 192: + return 192 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]: + if head_size <= 256: + return 256 + return 256 + + +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") def _flash_attn_forward( - q, - k, - v, - k_new, - v_new, - qv, - out, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - page_table, - kv_batch_idx, - leftpad_k, - rotary_cos, - rotary_sin, - seqlens_rotary, - q_descale, - k_descale, - v_descale, - softmax_scale, - causal, - window_size=(-1, -1), - attention_chunk=0, - softcap=0.0, - rotary_interleaved=True, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -62,14 +90,14 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd( q, k, v, k_new, v_new, qv, - out, + out_, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, @@ -88,8 +116,8 @@ def _flash_attn_forward( v_descale, softmax_scale, causal, - window_size[0], - window_size[1], + window_size_left, + window_size_right, attention_chunk, softcap, rotary_interleaved, @@ -98,59 +126,314 @@ def _flash_attn_forward( pack_gqa, sm_margin, ) - return out, softmax_lse, *rest + if out_accum is None: + out_accum = torch.tensor([], device=out.device) + + if softmax_lse_accum is None: + softmax_lse_accum = torch.tensor([], device=out.device) + + return out, softmax_lse, out_accum, softmax_lse_accum + + +@torch.library.register_fake("flash_attn_3::_flash_attn_forward") +def _flash_attn_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Symbolic fake implementation of flash attention forward. + Returns tensors with the correct shapes and dtypes without actual computation. + """ + + # Determine if we're in varlen mode + is_varlen_q = cu_seqlens_q is not None + # Get dimensions from query tensor + if is_varlen_q: + # varlen mode: q is (total_q, num_heads, head_size) + total_q, num_heads, head_size = q.shape + batch_size = cu_seqlens_q.shape[0] - 1 + + if max_seqlen_q is None: + raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided") + seqlen_q = max_seqlen_q + else: + # batch mode: q is (batch_size, seqlen_q, num_heads, head_size) + batch_size, seqlen_q, num_heads, head_size = q.shape + total_q = batch_size * q.shape[1] + # Get value head dimension + head_size_v = v.shape[-1] + + # Determine output dtype (FP8 inputs produce BF16 outputs) + q_type = q.dtype + if q_type == torch.float8_e4m3fn: + out_dtype = torch.bfloat16 + else: + out_dtype = q_type + + # Create output tensor + if out_ is not None: + # If out_ is provided, _flash_attn_forward becomes non-functional + raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.") + + if is_varlen_q: + out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + else: + out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + + # Create softmax_lse tensor + if is_varlen_q: + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device) + else: + softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + + # TODO(guilhermeleobas): Implement "get_num_splits" + # There's an heuristic to compute num_splits when "num_splits <= 0" + # assert that num_splits is > 0 for now + if num_splits <= 0: + raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}") + + if num_splits > 1: + if is_varlen_q: + out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device) + else: + out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + else: + # Tensors are not set when num_splits < 1 + out_accum = torch.tensor([], device=out.device) + softmax_lse_accum = torch.tensor([], device=out.device) + + return out, softmax_lse, out_accum, softmax_lse_accum + + +@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + softmax_d, *rest = flash_attn_3_cuda.bwd( dout, q, k, v, out, softmax_lse, + dq, + dk, + dv, cu_seqlens_q, cu_seqlens_k, sequed_q, sequed_k, max_seqlen_q, max_seqlen_k, - dq, - dk, - dv, softmax_scale, - causal, - window_size=(-1, -1), - softcap=0.0, - deterministic=False, - sm_margin=0, -): - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( + is_causal, + window_size_left, + window_size_right, + softcap, + deterministic, + sm_margin, + ) + return softmax_d + + +@torch.library.register_fake("flash_attn_3::_flash_attn_backward") +def _flash_attn_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> torch.Tensor: + + is_varlen_q = cu_seqlens_q is not None + is_varlen_k = cu_seqlens_q is not None + is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None + + if not is_varlen_q: + batch_size = q.size(0) + seqlen_q = q.size(1) + seqlen_k = k.size(1) + total_q = batch_size * q.size(1) + else: + batch_size = cu_seqlens_q.size(0) - 1 + total_q = q.size(0) + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if window_size_left >= seqlen_k - 1: + window_size_left = -1 + + if window_size_right >= seqlen_q - 1: + window_size_right = -1 + + if is_causal: + window_size_right = 0 + + is_causal = window_size_left < 0 and window_size_right == 0 + + head_size = q.size(-1) + head_size_v = v.size(-1) + head_size_rounded = round_up_headdim(max(head_size, head_size_v)) + + # Hopper gpus uses cuda compute capabilities 9.0 + cap = torch.cuda.get_device_capability(q.device) + arch = cap[0] * 10 + cap[1] + + is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal + + if head_size_rounded <= 64: + kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128 + elif head_size_rounded <= 96: + kBlockM_sm90 = 64 + elif head_size_rounded <= 128: + kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80 + else: + kBlockM_sm90 = 64 + + kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64 + kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32 + + if arch >= 90: + kBlockM = kBlockM_sm90 + elif arch == 86 or arch == 89: + kBlockM = kBlockM_sm86 + else: + kBlockM = kBlockM_sm80 + + num_heads = q.shape[-2] + seqlen_q_rounded = round_multiple(seqlen_q, kBlockM) + + total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM) + + dq = torch.empty_like(q) if dq is None else dq + dk = torch.empty_like(k) if dk is None else dk + dv = torch.empty_like(v) if dv is None else dv + + if not is_varlen: + softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device) + else: + softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device) + + return softmax_d + + +def setup_context(ctx, inputs, output): + q, k, v = inputs[:3] + out, softmax_lse, _, _ = output + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.softmax_scale = inputs[-11] + ctx.causal = inputs[-10] + ctx.window_size = [inputs[-9], inputs[-8]] + ctx.attention_chunk = inputs[-7] + ctx.softcap = inputs[-6] + ctx.sm_margin = inputs[-1] + + +def _backward(ctx, dout, *grads): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( dout, q, k, v, out, softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, - cu_seqlens_q, - cu_seqlens_k, - sequed_q, - sequed_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - deterministic, - sm_margin, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + False, # deterministic + ctx.sm_margin, ) - return dq, dk, dv, softmax_d + return dq, dk, dv, *((None,) * 21) + + +_flash_attn_forward.register_autograd(_backward, setup_context=setup_context) + class FlashAttnQKVPackedFunc(torch.autograd.Function): @@ -167,6 +450,7 @@ def forward( deterministic=False, num_heads_q=None, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) @@ -194,7 +478,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, sm_margin=sm_margin, @@ -209,8 +494,7 @@ def forward( ctx.deterministic = deterministic ctx.ndim = qkv.dim() ctx.sm_margin = sm_margin - # return out, softmax_lse - return out + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): @@ -241,13 +525,14 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @@ -269,6 +554,7 @@ def forward( pack_gqa=None, deterministic=False, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -288,7 +574,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, @@ -304,7 +591,7 @@ def forward( ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): @@ -326,7 +613,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -362,6 +650,7 @@ def forward( pack_gqa=None, deterministic=False, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -385,7 +674,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, @@ -403,7 +693,7 @@ def forward( ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): @@ -428,7 +718,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -450,6 +741,7 @@ def flash_attn_qkvpacked_func( deterministic=False, num_heads_q=None, sm_margin=0, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -496,6 +788,7 @@ def flash_attn_qkvpacked_func( deterministic, num_heads_q, sm_margin, + return_attn_probs, ) @@ -514,6 +807,7 @@ def flash_attn_func( pack_gqa=None, deterministic=False, sm_margin=0, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -575,6 +869,7 @@ def flash_attn_func( pack_gqa, deterministic, sm_margin, + return_attn_probs, ) @@ -599,6 +894,7 @@ def flash_attn_varlen_func( pack_gqa=None, deterministic=False, sm_margin=0, + return_attn_probs=False, ): return FlashAttnVarlenFunc.apply( q, @@ -621,6 +917,7 @@ def flash_attn_varlen_func( pack_gqa, deterministic, sm_margin, + return_attn_probs, ) @@ -706,7 +1003,7 @@ def flash_attn_with_kvcache( q: (batch_size, seqlen, nheads, headdim) k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. + page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.). v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate @@ -751,7 +1048,7 @@ def flash_attn_with_kvcache( softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) out, softmax_lse, *rest = _flash_attn_forward( @@ -778,7 +1075,8 @@ def flash_attn_with_kvcache( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, rotary_interleaved=rotary_interleaved, diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index b6e8810b25f..6df3231cdd4 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -94,8 +94,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { flash::CollectiveEpilogueBwdGQA >; using Scheduler = std::conditional_t< - Is_causal && !Varlen, - flash::SingleTileBwdLPTScheduler, + Is_causal, + flash::SingleTileBwdLPTScheduler, flash::SingleTileScheduler >; using AttnKernel = std::conditional_t< @@ -165,6 +165,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum } }(), + params.b, params.h, params.dk_semaphore, params.dv_semaphore, @@ -301,10 +302,11 @@ template(params, stream); - run_flash_bwd(params, stream); -// }); + BOOL_SWITCH(params.deterministic, Deterministic_, [&] { + static constexpr bool Deterministic = Deterministic_ && kHeadDim < 256; + // run_flash_bwd(params, stream); + run_flash_bwd(params, stream); + }); }); }); } diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index a22e05969d9..05667698006 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -145,6 +145,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -164,6 +165,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -187,7 +189,9 @@ class FlashAttnFwdCombine { args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, - args.semaphore_to_reset + args.varlen_batch_idx_ptr, + args.semaphore_to_reset, + }; } @@ -203,8 +207,9 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = blockIdx.z; - int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + int const maybe_virtual_batch = blockIdx.z; + int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial); if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { cutlass::arch::wait_on_dependent_grids(); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 11d422924b4..a2ff25dcd5f 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -35,7 +35,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index b8af2977f11..d48a4fd9562 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -57,8 +57,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; + static constexpr bool LPT = Is_causal || Is_local; + static constexpr bool Sort = !Is_local; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -149,14 +151,16 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, - params.seqlen_k, params.d, params.dv, sizeof(Element), + params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, - // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, + params.num_m_blocks_ptr, + params.varlen_batch_idx_ptr, + params.num_nheads_in_l2_ptr }; - if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + if (Varlen && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 && params.prepare_varlen_pdl /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -189,7 +193,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } // kernel<<>>(kernel_params); cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, - Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + Arch >= 90 && Varlen && !params.skip_scheduler_metadata_computation && params.prepare_varlen_pdl /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } @@ -205,7 +209,6 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 7093fff32b6..1d810c015ed 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -2,6 +2,7 @@ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ +#include #include "cutlass/fast_math.h" #include "cutlass/barrier.h" #include "cutlass/arch/barrier.h" @@ -10,8 +11,35 @@ #include "flash.h" +#include "static_switch.h" + namespace flash { +// Sort in descending order +template +struct PrepareSortOp +{ + __device__ __forceinline__ bool operator()(T const & lhs, T const & rhs) + { + return lhs > rhs; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int2 const & lhs, int2 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int4 const & lhs, int4 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template __global__ void prepare_varlen_num_blocks_kernel( int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, @@ -19,16 +47,28 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, - // int* const num_m_blocks_ptr, + int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr, - bool enable_pdl) { + int* const varlen_batch_idx_ptr, + // int* const num_n_blocks_ptr, + int* const num_nheads_in_l2_ptr, + bool enable_pdl, + bool is_causal, + bool packgqa, + int max_kvblocks_in_l2) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; - // Assume that there's only one block in the grid + static constexpr int BLOCK_DIM_X = NumWarps * 32; + static constexpr int ITEMS_PER_THREAD = 1; + static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32); + using BlockMergeSort = cub::BlockMergeSort; + __shared__ int total_blocks_smem[kSmemSize]; - // There's only 1 block in the grid, so might as well start launching the main attn kernel + // Allocate shared memory for BlockMergeSort operations + __shared__ typename BlockMergeSort::TempStorage temp_storage; + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } @@ -38,8 +78,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - auto get_num_m_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_m_blocks = [&](int batch_idx) { int seqlen; if (seqused_q) { seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; @@ -50,13 +89,12 @@ __global__ void prepare_varlen_num_blocks_kernel( } else { seqlen = seqlen_q_static; } - seqlen *= qhead_per_khead; + if(packgqa) { seqlen *= qhead_per_khead; } return batch_idx < num_batch && lane < kNumBatchPerWarp ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; }; - auto get_num_n_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_n_blocks = [&](int batch_idx) { int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; int seqlen; if (seqused_k) { @@ -83,42 +121,130 @@ __global__ void prepare_varlen_num_blocks_kernel( }; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - int bidb_start = kNumBatchPerWarp * warp_idx; - int num_m_blocks = get_num_m_blocks(bidb_start); - int num_n_blocks = get_num_n_blocks(bidb_start); - - int total_blocks = num_m_blocks * num_n_blocks; - // Warp sum - #pragma unroll - for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { - total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + int batch_cta_idx_offset = int(blockIdx.x) * 992; + int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx; + int batch_idx = lane + bidb_start; + int num_m_blocks = get_num_m_blocks(batch_idx); + int num_n_blocks = get_num_n_blocks(batch_idx); + + auto get_nheads_in_l2 = [&](int n_blocks) { + int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 16 + : n_blocks * 8 <= max_kvblocks_in_l2 ? 8 + : n_blocks * 4 <= max_kvblocks_in_l2 ? 4 + : n_blocks * 2 <= max_kvblocks_in_l2 ? 2 + : 1; + if(!packgqa) { nheads_in_l2 *= qhead_per_khead; } + return min(nheads_in_l2, num_head); + }; + + int num_splits_dynamic; + if (int(gridDim.x) > 1 || num_splits_static == 1) { + // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) + // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) + num_splits_dynamic = 1; + } else { + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); } - if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } - __syncthreads(); - total_blocks = total_blocks_smem[0]; - // 10% margin - int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); - // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + + if constexpr (Sort) { + if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { + num_n_blocks = INT_MIN; // sort last + } else if (is_causal) { + // sort by shortest member to process + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; + } + int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + // Sort batches by num_n_blocks in descending order + BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); + + // if (threadIdx.x == 0) { + // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + if (is_causal) { + // reset value to num_n_blocks + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); + } + + // When sorting, we re-index some metadata by 'virtual batch index' + // and also store the vbidx -> bidx mapping. + // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] + // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] + // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx + batch_idx = batch_cta_idx_offset + threadIdx.x; + if (batch_idx < num_batch && threadIdx.x < 992) { + // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } + num_m_blocks_ptr[batch_idx] = batch_coords[0].y; + num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; + varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w; + } + } else { + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; + num_m_blocks_ptr[batch_idx] = num_m_blocks; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } } + } } // flash void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl) { - // Only support batch <= 992 (32 warps, each with 31 batches) - int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( - params.seqlen_q, params.seqlen_k, params.seqlen_knew, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, - cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, - // params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr, enable_pdl); + int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); + int num_warps = cutlass::ceil_div(params.b, 31); // warp switch will cap this at 32 + int num_ctas = cutlass::ceil_div(params.b, 31 * 32); + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice + int const element_size = params.is_e4m3 ? 1 : 2; + int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; + // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock); + int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; + BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] { + NUM_WARP_SWITCH(num_warps, NumWarps, [&] { + flash::prepare_varlen_num_blocks_kernel<<>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + params.varlen_batch_idx_ptr, + // params.num_n_blocks_ptr, + params.num_nheads_in_l2_ptr, + enable_pdl, + params.is_causal, + packgqa, + max_kvblocks_in_l2); + }); + }); } diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index ec34e20eca1..c67ae17969f 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -607,7 +607,8 @@ struct CollectiveMainloopBwdSm90 { seqlen_info, n_block, bidb, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early - if constexpr (Is_causal || Is_local || Varlen) { + // Though if local and deterministic, still need to increment dq semaphore + if constexpr ((Is_causal || Is_local || Varlen) && !(Is_local && Deterministic)) { if (m_block_max <= m_block_min) { return; } } @@ -626,10 +627,18 @@ struct CollectiveMainloopBwdSm90 { using Barrier = cutlass::GenericBarrier; bool const lane_predicate = cute::elect_one_sync(); int m_block = m_block_min; + constexpr int kBlockM = get<0>(TileShape_MNK{}); + constexpr int kBlockN = get<1>(TileShape_MNK{}); + int n_block_global_max = cute::ceil_div(seqlen_info.seqlen_k, kBlockN); #pragma unroll 2 for (; m_block < m_block_max; ++m_block) { if constexpr (Deterministic) { - Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + if constexpr(Is_causal) { + int n_block_max_for_m_block = std::min(n_block_global_max, cute::ceil_div((m_block + 1) * kBlockM + seqlen_info.seqlen_k - seqlen_info.seqlen_q, kBlockN)); + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block_max_for_m_block - 1 - n_block); + } else { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + } } #pragma unroll for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { @@ -649,7 +658,6 @@ struct CollectiveMainloopBwdSm90 { } } if constexpr (Is_local && Deterministic) { - constexpr int kBlockM = get<0>(TileShape_MNK{}); int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); #pragma unroll 2 for (; m_block < m_block_global_max; ++m_block) { @@ -930,7 +938,7 @@ struct CollectiveMainloopBwdSm90 { Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); - pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ + pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO if constexpr (Mma_dKV_is_RS) { Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); diff --git a/hopper/setup.py b/hopper/setup.py index c15c438f56c..95729edabe2 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -64,6 +64,8 @@ ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" +DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', @@ -80,6 +82,42 @@ _maybe_write, ) +def create_build_config_file(): + CONFIG = { + "build_flags": { + "FLASHATTENTION_DISABLE_BACKWARD": DISABLE_BACKWARD, + "FLASHATTENTION_DISABLE_SPLIT": DISABLE_SPLIT, + "FLASHATTENTION_DISABLE_PAGEDKV": DISABLE_PAGEDKV, + "FLASHATTENTION_DISABLE_APPENDKV": DISABLE_APPENDKV, + "FLASHATTENTION_DISABLE_LOCAL": DISABLE_LOCAL, + "FLASHATTENTION_DISABLE_SOFTCAP": DISABLE_SOFTCAP, + "FLASHATTENTION_DISABLE_PACKGQA": DISABLE_PACKGQA, + "FLASHATTENTION_DISABLE_FP16": DISABLE_FP16, + "FLASHATTENTION_DISABLE_FP8": DISABLE_FP8, + "FLASHATTENTION_DISABLE_VARLEN": DISABLE_VARLEN, + "FLASHATTENTION_DISABLE_CLUSTER": DISABLE_CLUSTER, + "FLASHATTENTION_DISABLE_HDIM64": DISABLE_HDIM64, + "FLASHATTENTION_DISABLE_HDIM96": DISABLE_HDIM96, + "FLASHATTENTION_DISABLE_HDIM128": DISABLE_HDIM128, + "FLASHATTENTION_DISABLE_HDIM192": DISABLE_HDIM192, + "FLASHATTENTION_DISABLE_HDIM256": DISABLE_HDIM256, + "FLASHATTENTION_DISABLE_SM8x": DISABLE_SM8x, + "FLASHATTENTION_ENABLE_VCOLMAJOR": ENABLE_VCOLMAJOR, + "FLASH_ATTENTION_DISABLE_HDIMDIFF64": DISABLE_HDIMDIFF64, + "FLASH_ATTENTION_DISABLE_HDIMDIFF192": DISABLE_HDIMDIFF192, + } + } + + with open("flash_attn_config.py", "w") as f: + f.write("# Auto-generated by flash attention 3 setup.py\n") + f.write(f"CONFIG = {repr(CONFIG)}\n") + f.write("\n") + + f.write("def show():\n") + f.write(" from pprint import pprint\n") + f.write(" pprint(CONFIG)\n") + f.write("\n") + def _write_ninja_file(path, cflags, post_cflags, @@ -393,15 +431,23 @@ def nvcc_threads_args(): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) + create_build_config_file() check_if_cuda_home_none(PACKAGE_NAME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") + elif bare_metal_version >= Version("13.0"): + # CUDA 13.0+ uses system nvcc and CCCL headers are in /usr/local/cuda/include/cccl/ + cccl_include = os.path.join(CUDA_HOME, "include", "cccl") + for env_var in ["CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"]: + current = os.environ.get(env_var, "") + os.environ[env_var] = cccl_include + (":" + current if current else "") # ptxas 12.8 gives the best perf currently # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. - if bare_metal_version != Version("12.8"): + # For CUDA 13.0+, use system nvcc instead of downloading CUDA 12.x toolchain + if bare_metal_version >= Version("12.3") and bare_metal_version < Version("13.0") and bare_metal_version != Version("12.8"): download_and_copy( name="nvcc", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", @@ -468,10 +514,13 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else []) + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else []) + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) ) DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else []) + HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) HEAD_DIMENSIONS_BWD = ( [] @@ -481,7 +530,18 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all", "diff"] + # build will now explode with this compilation grouping given all our templating + # HEAD_DIMENSIONS_FWD = ["all", "diff"] + HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD + HEAD_DIMENSIONS_DIFF64_FWD = ( + [] + + (["64_256"] if not DISABLE_HDIMDIFF64 else []) + + (["64_512"] if not DISABLE_HDIMDIFF64 else []) + ) + HEAD_DIMENSIONS_DIFF192_FWD = ( + [] + + (["192_128"] if not DISABLE_HDIMDIFF192 else []) + ) HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) @@ -495,6 +555,14 @@ def nvcc_threads_args(): sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF64: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF192: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu" for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)] sources_bwd_sm90 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu" @@ -502,8 +570,20 @@ def nvcc_threads_args(): if DISABLE_BACKWARD: sources_bwd_sm90 = [] sources_bwd_sm80 = [] + + # Choose between flash_api.cpp and flash_api_stable.cpp based on torch version + torch_version = parse(torch.__version__) + target_version = parse("2.9.0.dev20250830") + stable_args = [] + + if torch_version >= target_version: + flash_api_source = "flash_api_stable.cpp" + stable_args = ["-DTORCH_STABLE_ONLY"] # Checks against including unstable Tensor APIs + else: + flash_api_source = "flash_api.cpp" + sources = ( - ["flash_api.cpp"] + [flash_api_source] + (sources_fwd_sm80 if not DISABLE_SM8x else []) + sources_fwd_sm90 + (sources_bwd_sm80 if not DISABLE_SM8x else []) + sources_bwd_sm90 ) @@ -542,7 +622,7 @@ def nvcc_threads_args(): name=f"{PACKAGE_NAME}._C", sources=sources, extra_compile_args={ - "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + feature_args, + "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + stable_args + feature_args, "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args, }, include_dirs=include_dirs, @@ -633,7 +713,7 @@ def run(self): "benchmarks", ) ), - py_modules=["flash_attn_interface"], + py_modules=["flash_attn_interface", "flash_attn_config"], description="FlashAttention-3", long_description=long_description, long_description_content_type="text/markdown", diff --git a/hopper/sm90_pipeline_no_cluster.hpp b/hopper/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..1fb805aec1f 100644 --- a/hopper/sm90_pipeline_no_cluster.hpp +++ b/hopper/sm90_pipeline_no_cluster.hpp @@ -39,7 +39,7 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( diff --git a/hopper/static_switch.h b/hopper/static_switch.h index 5e13b5f93a8..15a7d51364b 100644 --- a/hopper/static_switch.h +++ b/hopper/static_switch.h @@ -179,3 +179,26 @@ return __VA_ARGS__(); \ } \ }() + +#define NUM_WARP_SWITCH(VALUE, CONST_NAME, ...) \ + [&] { \ + if (VALUE <= 1) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 2) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 4) { \ + constexpr static int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 8) { \ + constexpr static int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 16) { \ + constexpr static int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 109b5fcac00..78a8e7c2cc4 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -6,6 +6,11 @@ import torch import torch.nn.functional as F from torch._C import parse_schema +from torch.testing._internal.optests.generate_tests import ( + safe_fake_check, + safe_schema_check, + safe_aot_autograd_check, +) from einops import rearrange, repeat try: @@ -38,6 +43,8 @@ DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" +ENABLE_OPCHECK = os.getenv("FLASH_ATTENTION_ENABLE_OPCHECK", "FALSE") == "TRUE" +ENABLE_AUTOGRAD_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] @@ -48,6 +55,61 @@ + ([256] if not DISABLE_HDIM256 else []) ) +def should_test_backward(args, kwargs): + v = args[2] + num_splits = kwargs.get("num_splits", 1) + dtype = v.dtype + has_qv = V_colmajor = False # no test runs this with V_colmajor or has_qv == True + attention_chunk = kwargs.get("attention_chunk") + dv = v.size(-1) + + if ( + ENABLE_AUTOGRAD_CHECK + and not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and num_splits > 0 # we don't support num_split == 0 on torch.compile yet + ): + return True + return False + + +def should_run_schema_check(args, kwargs): + v = args[2] + if v.dtype == torch.float8_e4m3fn: + return False + return True + + +def should_run_fake_check(args, kwargs): + if 'num_splits' in kwargs: + return kwargs['num_splits'] > 0 + return True + + +def run_opcheck(fn): + def wrapper(*args, **kwargs): + if should_run_schema_check(args, kwargs): + safe_schema_check(fn, args, kwargs) + + if should_run_fake_check(args, kwargs): + safe_fake_check(fn, args, kwargs) + + if should_test_backward(args, kwargs): + # Expensive check + safe_aot_autograd_check(fn, args, kwargs, dynamic=False) + safe_aot_autograd_check(fn, args, kwargs, dynamic=True) + return fn(*args, **kwargs) + return wrapper + + +if ENABLE_OPCHECK: + flash_attn_func = run_opcheck(flash_attn_func) + flash_attn_varlen_func = run_opcheck(flash_attn_varlen_func) + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) @@ -55,8 +117,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -75,7 +137,7 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -107,6 +169,8 @@ def test_flash_attn_output( ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(0) @@ -121,8 +185,11 @@ def test_flash_attn_output( dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + dv_vals = [256, 512] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -193,7 +260,8 @@ def test_flash_attn_output( pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( + print(f"{pack_gqa = }, {num_splits = }") + out = flash_attn_func( q, k, v, @@ -286,8 +354,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -295,7 +363,7 @@ def test_flash_attn_output( @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("add_unused_qkv", [False, True]) # @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -305,7 +373,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -328,28 +396,38 @@ def test_flash_attn_output( (1024, 1024), (1023, 1024), (1024, 1023), + (1024, 1024), (2048, 2048), + (4096, 4096), ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, ): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 # nheads = 16 batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) # batch_size = 2 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # nheads_kv = nheads + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + dv_vals = [256, 512] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -458,9 +536,16 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + # pack_gqa_vals = [False] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + # num_splits_vals = [1] + # print("cu_seqlens_q: ", cu_seqlens_q) + # print("cu_seqlens_k: ", cu_seqlens_k) + # print("seqused_q: ", seqused_q) + # print("seqused_k: ", seqused_k) for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( + print(f"{pack_gqa = }, {num_splits = }") + out_unpad = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, @@ -477,6 +562,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -580,16 +667,16 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -# @pytest.mark.parametrize("causal,local", [(False, False)]) +# @pytest.mark.parametrize("causal,local", [(True, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) -# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) -# @pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) @@ -597,9 +684,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False, True]) -# @pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [True]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -669,6 +756,7 @@ def test_flash_attn_kvcache( dv_vals = [d] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: @@ -850,17 +938,21 @@ def test_flash_attn_kvcache( sin = sin.to(dtype) if sin is not None else None k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() - num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + print(f"{num_splits = }, {precompute_metadata = }") if precompute_metadata: scheduler_metadata = get_scheduler_metadata( - batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + batch_size, + max_seqlen_q if varlen_q else seqlen_q, + seqlen_k if page_size is None else page_table.shape[1] * page_size, + nheads, nheads_k, d, cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, causal=causal, window_size=window_size, attention_chunk=attention_chunk, - num_splits=num_splits + num_splits=num_splits, ) else: scheduler_metadata = None @@ -895,7 +987,7 @@ def test_flash_attn_kvcache( rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, num_splits=num_splits, - return_softmax_lse=True + return_softmax_lse=True, ) if varlen_q: out = output_pad_fn(out) @@ -1050,7 +1142,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) torch.random.manual_seed(42) - out0, lse0 = flash_attn_func(q, k, v, causal=causal) + out0 = flash_attn_func(q, k, v, causal=causal) g = torch.randn_like(out0) dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) # Numerical error if we just do any arithmetic on dq @@ -1058,9 +1150,9 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): for i in range(1000): torch.random.manual_seed(42) - out, lse = flash_attn_func(q, k, v, causal=causal) + out = flash_attn_func(q, k, v, causal=causal) assert torch.equal(out, out0) - assert torch.equal(lse, lse0) + # assert torch.equal(lse, lse0) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_equal = torch.allclose(dq, dq0, atol=dq_atol) diff --git a/hopper/test_flash_attn_bwd_determinism.py b/hopper/test_flash_attn_bwd_determinism.py new file mode 100644 index 00000000000..b443c8948d4 --- /dev/null +++ b/hopper/test_flash_attn_bwd_determinism.py @@ -0,0 +1,706 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + +from flash_attn_interface import _flash_attn_backward + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +# deterministic mode not supported for hdim 256 +DISABLE_HDIM256 = True + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mqa"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("V_colmajor", [False, True]) +@pytest.mark.parametrize("V_colmajor", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + # (4224, 4224), + # (8192, 8192), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") + if deterministic and d == 256: + pytest.skip("Deterministic mode not supported for hdim 256") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # if dtype == torch.float8_e4m3fn: + # dv_vals = [d] + # if has_qv: + # dv_vals = [256, 512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + dv_vals = [d] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + pack_gqa_vals = [False] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out, softmax_lse = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, + return_attn_probs=True, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dq, dk, dv, softmax_d = _flash_attn_backward( + g, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq, + dk, + dv, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + if deterministic: + iterations = 1000 + + for i in range(iterations): + dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + dv2 = torch.empty_like(dv) + dq2, dk2, dv2, softmax_d = _flash_attn_backward( + g, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq2, + dk2, + dv2, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') + print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') + print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}') + assert torch.equal(dq, dq2), f"dq not deterministic" + assert torch.equal(dk, dk2), f"dk not deterministic" + assert torch.equal(dv, dv2), f"dv not deterministic" + print(f"✅ Iteration {i} passed!") + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (1024, 1024), + (2048, 2048), + (4096, 4096), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, +): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") + if deterministic and d == 256: + pytest.skip("Deterministic mode not supported for hdim 256") + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # batch_size = 2 + # nheads = 1 + # nheads_kv = nheads + + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # if dtype == torch.float8_e4m3fn: + # dv_vals = [d] + # if has_qv: + # dv_vals = [256, 512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + dv_vals = [d] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + # num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + pack_gqa_vals = [False] + num_splits_vals = [1] + print("cu_seqlens_q: ", cu_seqlens_q) + print("cu_seqlens_k: ", cu_seqlens_k) + print("seqused_q: ", seqused_q) + print("seqused_k: ", seqused_k) + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out_unpad, softmax_lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + return_attn_probs=True, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + dq_unpad = torch.empty_like(q_unpad) + dk_unpad = torch.empty_like(k_unpad) + dv_unpad = torch.empty_like(v_unpad) + dq_unpad, dk_unpad, dv_unpad, softmax_d = _flash_attn_backward( + g_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + softmax_lse, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + dq_unpad, + dk_unpad, + dv_unpad, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + print(dq_unpad.shape) + print(dk_unpad.shape) + print(dv_unpad.shape) + + print(dq.shape) + print(dk.shape) + print(dv.shape) + + if deterministic: + iterations = 1000 + + for i in range(iterations): + dq_unpad2 = torch.empty_like(q_unpad) + dk_unpad2 = torch.empty_like(k_unpad) + dv_unpad2 = torch.empty_like(v_unpad) + dq_unpad2, dk_unpad2, dv_unpad2, softmax_d = _flash_attn_backward( + g_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + softmax_lse, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + dq_unpad2, + dk_unpad2, + dv_unpad2, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + + dq2 = dq_pad_fn(dq_unpad2) + dk2 = dk_pad_fn(dk_unpad2) + dv2 = dk_pad_fn(dv_unpad2) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk2.masked_fill_(k_zero_masking, 0.0) + dv2.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq2.masked_fill_(q_zero_masking, 0.0) + + print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') + print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') + print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}') + + assert torch.equal(dq, dq2), f"dq not deterministic" + assert torch.equal(dk, dk2), f"dk not deterministic" + assert torch.equal(dv, dv2), f"dv not deterministic" + + print(f"✅ Iteration {i} passed!") \ No newline at end of file diff --git a/hopper/test_torch_compile_and_export.py b/hopper/test_torch_compile_and_export.py new file mode 100644 index 00000000000..53beef46340 --- /dev/null +++ b/hopper/test_torch_compile_and_export.py @@ -0,0 +1,73 @@ +import torch +from flash_attn_interface import flash_attn_func +from torch import nn + + +class EfficienctMultiHeadAttention(nn.Module): + def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True): + super().__init__() + assert embed_size % num_heads == 0, f"{embed_size=} {num_heads=}" + + self.embed_size = embed_size + self.num_heads = num_heads + self.head_dim = embed_size // num_heads + self.use_flash_attn = use_flash_attn and (flash_attn_func is not None) + + self.qkv_proj = nn.Linear(embed_size, 3 * embed_size) + self.out_proj = nn.Linear(embed_size, embed_size) + self.dropout = dropout + + def forward(self, x, attention_mask=None): + N, seq_length, _ = x.shape + + qkv = self.qkv_proj(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(N, seq_length, self.num_heads, self.head_dim) + k = k.view(N, seq_length, self.num_heads, self.head_dim) + v = v.view(N, seq_length, self.num_heads, self.head_dim) + + if self.use_flash_attn and attention_mask is None: + out = flash_attn_func( + q, k, v + ) + out = out.reshape(N, seq_length, self.embed_size) + out = self.out_proj(out) + return out + + +def create_model(batch_size=16, sequence_length=256, embedding_dim=2048, num_heads=16): + model = EfficienctMultiHeadAttention(embedding_dim, num_heads).cuda().bfloat16() + input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16() + return model, input_tensor + + +def test_export_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + loss = expected.sum() + loss.backward() + + ep = torch.export.export(model, (input_tensor,)) + got = ep.module()(input_tensor,) + assert torch.equal(expected, got) + + loss_2 = got.sum() + loss_2.backward() + + assert torch.equal(loss, loss_2) + + +def test_compile_and_package_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + + exported = torch.export.export(model, (input_tensor,)) + torch._inductor.aoti_compile_and_package( + exported, + package_path="model.pt2", + ) + + compiled_model = torch._inductor.package.load_package("model.pt2") + out = compiled_model(input_tensor,) + assert torch.equal(expected, out) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 1f90f66adc2..241eaed40f8 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -24,8 +24,11 @@ struct TileSchedulerArguments { int* const tile_count_semaphore = nullptr; int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; - // int const* const num_m_blocks_ptr = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const num_m_blocks_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; + // int const* const num_n_blocks_ptr = nullptr; + int const* const num_nheads_in_l2_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -248,7 +251,7 @@ class DynamicPersistentTileScheduler { static Params to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size; + long long const size_one_kv_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead @@ -361,6 +364,7 @@ class DynamicPersistentTileScheduler { /////////////////////////////////////////////////////////////////////////////// +template class SingleTileBwdLPTScheduler { public: @@ -370,18 +374,21 @@ class SingleTileBwdLPTScheduler { // Device side kernel params struct Params { int const total_blocks; - cutlass::FastDivmod const m_block_divmod, head_divmod; + cutlass::FastDivmod const block_divmod, head_divmod; cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; cutlass::FastDivmod const l2_minor_residual_divmod; int const num_hb_quotient; + int const seqlen; + int const* const cu_seqlens; + int const* const seqused; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { // Since it's the bwd pass, seqlen_k get passed to args.seqlen and seqlen_q is passed to args.seqlen_k - int const size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size; - int const size_one_dqaccum_head = args.seqlen_k * args.headdim * sizeof(float); - int const size_one_head = size_one_qdo_head + size_one_dqaccum_head; + long long const size_one_qdo_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); + long long const size_one_dqaccum_head = long(args.seqlen_k) * long(args.headdim) * sizeof(float); + long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head; int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum // Swizzle is the size of each "section". Round swizzle to a power of 2 // Need to be careful about the case where only one head will fit @@ -398,7 +405,8 @@ class SingleTileBwdLPTScheduler { cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks), // don't divide by 0 cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), - (args.num_head * args.num_batch) / swizzle}; + (args.num_head * args.num_batch) / swizzle, + args.seqlen, !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused}; } static dim3 @@ -407,28 +415,19 @@ class SingleTileBwdLPTScheduler { } struct WorkTileInfo { - int tile_idx; + int block; + int bidh; + int bidb; CUTLASS_DEVICE bool is_valid(Params const& params) const { - return tile_idx < params.total_blocks; + return bidb >= 0; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { - int block, bidh, bidb; - int l2_mod, bidhb, bidhb_residual; - bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); - // If we're in the last section (called residual), we don't want to divide by - // swizzle. Instead we want to divide by the remainder. - if (bidhb < params.num_hb_quotient) { - block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); - } else { - block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); - } - bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); return {block, bidh, bidb, 0 /*split_idx*/}; } @@ -441,7 +440,33 @@ class SingleTileBwdLPTScheduler { CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { - return {int(blockIdx.x)}; + int tile_idx = blockIdx.x; + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + bool is_valid_tile = true; + int num_blocks; + if constexpr (Varlen) { + int seqlen = params.seqused + ? params.seqused[bidb] + : (params.cu_seqlens ? params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] : params.seqlen); + num_blocks = cute::ceil_div(seqlen, Int{}); + is_valid_tile = block < num_blocks; + } else { + num_blocks = params.block_divmod.divisor; + } + if constexpr (SPT) { + block = num_blocks - block - 1; + } + return {block, bidh, is_valid_tile ? bidb : -1}; } CUTLASS_DEVICE @@ -456,14 +481,15 @@ class SingleTileBwdLPTScheduler { CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {params.total_blocks}; + return {0, 0, -1}; } }; /////////////////////////////////////////////////////////////////////////////// -template +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -482,13 +508,17 @@ class VarlenDynamicPersistentTileScheduler { int num_head, num_batch; int const qhead_per_khead; int const seqlen; + // int const max_kvblocks_in_l2; cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; int const* const cu_seqlens; int const* const seqused; - // int* const num_m_blocks_ptr; int const* const num_splits_dynamic_ptr; + int const* const num_m_blocks_ptr; + int const* const varlen_batch_idx_ptr; + // int const* const num_n_blocks_ptr; + int const* const num_nheads_in_l2_ptr; }; static Params @@ -498,13 +528,20 @@ class VarlenDynamicPersistentTileScheduler { assert(args.tile_count_semaphore != nullptr); assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + // int const size_one_kvblock = kBlockN * (args.headdim + args.headdim_v) * args.element_size; + // int max_kvblocks_in_l2 = size_l2 / size_one_kvblock; return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, + // max_kvblocks_in_l2, cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, - args.num_splits_dynamic_ptr}; + args.num_splits_dynamic_ptr, + args.num_m_blocks_ptr, + args.varlen_batch_idx_ptr, + // aras.num_n_blocks_ptr, + args.num_nheads_in_l2_ptr}; } static dim3 @@ -525,8 +562,15 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { + auto get_actual_batch = [&](int virtual_batch) { + if constexpr(Prepared && Sort) { + return params.varlen_batch_idx_ptr[virtual_batch]; + } else { + return virtual_batch; + } + }; if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; + return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/}; } else { // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift @@ -540,7 +584,7 @@ class VarlenDynamicPersistentTileScheduler { // if (threadIdx.x == 128) { // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); // } - return {block, bidh_actual, bidb, split_idx}; + return {block, bidh_actual, get_actual_batch(bidb), split_idx}; } } }; @@ -554,31 +598,39 @@ class VarlenDynamicPersistentTileScheduler { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { - if (params.seqused) { - seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; + if constexpr (Prepared) { + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? params.num_m_blocks_ptr[batch_idx] : 0; + } else { + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlockM) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlockM) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; } - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; - // ? params.num_m_blocks_ptr[batch_idx] : 0; }; auto get_num_splits = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? (!Split ? 1 : (params.num_splits_dynamic_ptr - ? params.num_splits_dynamic_ptr[batch_idx] - : params.nsplits_divmod.divisor)) - : 0; + bool is_valid = batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1; + if constexpr (!Split) { + return is_valid ? 1 : 0; + } else if constexpr(Prepared) { + return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0; + } else { + return is_valid ? params.nsplits_divmod.divisor : 0; + } }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane @@ -589,12 +641,14 @@ class VarlenDynamicPersistentTileScheduler { // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); // Only the lower 16 bits are the actual bidh - int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); - int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes - if constexpr (Split) { - int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; - group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); - } + // int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + // int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // if constexpr (Split) { + // int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + // group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + // } + // NEW: current_work.tile_idx holds group_start_tile for starting batch + int group_end_tile = current_work.tile_idx + m_blocks_in_group * params.num_head; // Same for all lanes int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); @@ -626,27 +680,81 @@ class VarlenDynamicPersistentTileScheduler { bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } - int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; - int bidh = mh_block / num_m_blocks; - int block = mh_block - bidh * num_m_blocks; - if constexpr (Split) { - int bidh_actual = bidh / num_splits; - int split_idx = bidh - bidh_actual * num_splits; - // TODO: idk why this gives wrong answer nondeterministically - // int bidh_actual, split_idx; - // split_idx = params.head_divmod.divmod(bidh_actual, bidh); - // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); - // if (threadIdx.x == 0) { - // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + group_start_tile += (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int mh_block = next_tile_idx - group_start_tile; + int block, bidh; + if constexpr (LPT) { + if (!Split || num_splits == 1) { + // NOTE: code for computing nheads_in_l2 directly left as reference + // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; + // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks + // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); + // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } + // nheads_in_l2 = min(nheads_in_l2, params.num_head); + auto get_nheads_in_l2 = [&](int batch_idx) { + if constexpr(Prepared) { + return params.num_nheads_in_l2_ptr[batch_idx]; + } else { + return !PackGQA ? params.qhead_per_khead : 1; + } + }; + int nheads_in_l2 = get_nheads_in_l2(bidb); + int mh_in_l2 = nheads_in_l2 * num_m_blocks; + int section_idx = mh_block / mh_in_l2; + int l2_mod = mh_block - section_idx * mh_in_l2; + // tail section + int nheads_remainder = params.num_head - section_idx * nheads_in_l2; + int nheads_in_this_section = nheads_in_l2 <= nheads_remainder ? nheads_in_l2 : nheads_remainder; + block = l2_mod / nheads_in_this_section; + int bidh_residual = l2_mod - block * nheads_in_this_section; + bidh = section_idx * nheads_in_l2 + bidh_residual; + if constexpr(Split) { + // remember to set num_splits = 1 in work tile + uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } else { + // NOTE: leave traverse heads first version for reference + // block = params.head_divmod.divmod(bidh, mh_block); + // if constexpr (Split) { + // int split_idx = block / num_m_blocks; + // block = block - split_idx * num_m_blocks; + // uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // bidh = reinterpret_cast(bidh_packed); + // } + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } + block = num_m_blocks - 1 - block; + } else { + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } - bidh = reinterpret_cast(bidh_packed); } - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); - // } - return {next_tile_idx, block, bidh, bidb}; + return {group_start_tile, block, bidh, bidb}; } template diff --git a/hopper/tile_size.h b/hopper/tile_size.h index e6cb31515c7..8353542c477 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -21,7 +21,7 @@ constexpr std::tuple tile_size_fwd_sm90( return {128, 96, true, false}; } else { // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen @@ -29,8 +29,9 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; - // {128, 192, false, false} and {192, 128, false, true} are quite good too + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {128, use_blockN_128 ? 128 : 176, true, true}; + // {128, 192, true, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem diff --git a/setup.py b/setup.py index a7f15a99724..9b1bd10088a 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;120").split(";") + return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120").split(";") def get_platform(): @@ -94,6 +94,59 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_version +def add_cuda_gencodes(cc_flag, archs, bare_metal_version): + """ + Adds -gencode flags based on nvcc capabilities: + - sm_80/90 (regular) + - sm_100/120 on CUDA >= 12.8 + - Use 100f on CUDA >= 12.9 (Blackwell family-specific) + - Map requested 110 -> 101 if CUDA < 13.0 (Thor rename) + - Embed PTX for newest arch for forward compatibility + """ + # Always-regular 80 + if "80" in archs: + cc_flag += ["-gencode", "arch=compute_80,code=sm_80"] + + # Hopper 9.0 needs >= 11.8 + if bare_metal_version >= Version("11.8") and "90" in archs: + cc_flag += ["-gencode", "arch=compute_90,code=sm_90"] + + # Blackwell 10.x requires >= 12.8 + if bare_metal_version >= Version("12.8"): + if "100" in archs: + # CUDA 12.9 introduced "family-specific" for Blackwell (100f) + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"] + else: + cc_flag += ["-gencode", "arch=compute_100,code=sm_100"] + + if "120" in archs: + # sm_120 is supported in CUDA 12.8/12.9+ toolkits + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"] + else: + cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] + + + # Thor rename: 12.9 uses sm_101; 13.0+ uses sm_110 + if "110" in archs: + if bare_metal_version >= Version("13.0"): + cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"] + else: + # Provide Thor support for CUDA 12.9 via sm_101 + if bare_metal_version >= Version("12.8"): + cc_flag += ["-gencode", "arch=compute_101,code=sm_101"] + # else: no Thor support in older toolkits + + # PTX for newest requested arch (forward-compat) + numeric = [a for a in archs if a.isdigit()] + if numeric: + newest = max(numeric, key=int) + cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"] + + return cc_flag + + def get_hip_version(): return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) @@ -120,6 +173,18 @@ def check_if_rocm_home_none(global_option: str) -> None: ) +def detect_hipify_v2(): + try: + from torch.utils.hipify import __version__ + from packaging.version import Version + if Version(__version__) >= Version("2.0.0"): + return True + except Exception as e: + print("failed to detect pytorch hipify version, defaulting to version 1.0.0 behavior") + print(e) + return False + + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "4" return nvcc_extra_args + ["--threads", nvcc_threads] @@ -132,7 +197,7 @@ def rename_cpp_to_cu(cpp_files): def validate_and_update_archs(archs): # List of allowed architectures - allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", "gfx1100"] # Validate if each element in archs is in allowed_archs assert all( @@ -175,26 +240,45 @@ def validate_and_update_archs(archs): "FlashAttention is only supported on CUDA 11.7 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." ) - - if "80" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if CUDA_HOME is not None: - if bare_metal_version >= Version("11.8") and "90" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - if bare_metal_version >= Version("12.8") and "100" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_100,code=sm_100") - if bare_metal_version >= Version("12.8") and "120" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_120,code=sm_120") + # Build -gencode (regular + PTX + family-specific 'f' when available) + add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version) + else: + # No nvcc present; warnings already emitted above + pass # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True + + nvcc_flags = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + # "--ptxas-options=-v", + # "--ptxas-options=-O2", + # "-lineinfo", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + # "-DFLASHATTENTION_DISABLE_DROPOUT", + # "-DFLASHATTENTION_DISABLE_ALIBI", + # "-DFLASHATTENTION_DISABLE_SOFTCAP", + # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + # "-DFLASHATTENTION_DISABLE_LOCAL", + ] + + compiler_c17_flag=["-O3", "-std=c++17"] + # Add Windows-specific flags + if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1': + nvcc_flags.extend(["-Xcompiler", "/Zc:__cplusplus"]) + compiler_c17_flag=["-O2", "/std:c++17", "/Zc:__cplusplus"] + ext_modules.append( CUDAExtension( name="flash_attn_2_cuda", @@ -274,30 +358,8 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", ], extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - # "--ptxas-options=-v", - # "--ptxas-options=-O2", - # "-lineinfo", - # "-DFLASHATTENTION_DISABLE_BACKWARD", - # "-DFLASHATTENTION_DISABLE_DROPOUT", - # "-DFLASHATTENTION_DISABLE_ALIBI", - # "-DFLASHATTENTION_DISABLE_SOFTCAP", - # "-DFLASHATTENTION_DISABLE_UNEVEN_K", - # "-DFLASHATTENTION_DISABLE_LOCAL", - ] - + cc_flag - ), + "cxx": compiler_c17_flag, + "nvcc": append_nvcc_threads(nvcc_flags + cc_flag), }, include_dirs=[ Path(this_dir) / "csrc" / "flash_attn", @@ -319,10 +381,11 @@ def validate_and_update_archs(archs): if not os.path.exists("./build"): os.makedirs("build") - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2"], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2"], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2"], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2"], check=True) + optdim = os.getenv("OPT_DIM", "32,64,128,256") + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 @@ -357,6 +420,12 @@ def validate_and_update_archs(archs): f"build/fmha_*wd*.cpp" ) + # Check if torch is using hipify v2. Until CK is updated with HIPIFY_V2 macro, + # we must replace the incorrect APIs. + maybe_hipify_v2_flag = [] + if detect_hipify_v2(): + maybe_hipify_v2_flag = ["-DHIPIFY_V2"] + rename_cpp_to_cu(sources) renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", @@ -399,8 +468,8 @@ def validate_and_update_archs(archs): cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] extra_compile_args = { - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": cc_flag + generator_flag, + "cxx": ["-O3", "-std=c++17"] + generator_flag + maybe_hipify_v2_flag, + "nvcc": cc_flag + generator_flag + maybe_hipify_v2_flag, } include_dirs = [ diff --git a/tests/cute/test_block_sparsity.py b/tests/cute/test_block_sparsity.py new file mode 100644 index 00000000000..d1ac5318004 --- /dev/null +++ b/tests/cute/test_block_sparsity.py @@ -0,0 +1,422 @@ +"""Tests for block sparsity computation in flash attention.""" + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask + +from flash_attn.cute.mask_definitions import get_mask_pair +from flash_attn.cute.compute_block_sparsity import compute_block_sparsity + + +def _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=None, + aux_tensors=None, + use_fast_sampling=False, +): + """Call compute_block_sparsity and return torch tensors.""" + cute_mask, _ = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + blocksparse_tensors, torch_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + mask_mod=cute_mask, + aux_tensors=aux_tensors, + device="cuda", + use_fast_sampling=use_fast_sampling, + ) + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = torch_tensors + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + + +def _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, +): + """Compare block sparsity against reference. Returns (all_match, error_msg).""" + if not isinstance(mask_block_cnt, torch.Tensor): + return False, f"mask_block_cnt is not a tensor: {type(mask_block_cnt)}" + + n_blocks_q = mask_block_cnt.shape[2] + mask_cnt_match = torch.all(mask_block_cnt == mask_block_cnt_ref).item() + full_cnt_match = torch.all(full_block_cnt == full_block_cnt_ref).item() + + if not mask_cnt_match or not full_cnt_match: + error_msg = [] + if not mask_cnt_match: + error_msg.append("Mask counts mismatch") + diff = (mask_block_cnt != mask_block_cnt_ref).nonzero(as_tuple=False) + if len(diff) > 0: + b, h, m = diff[0].tolist() + error_msg.append( + f" First mismatch at [{b},{h},{m}]: " + f"got {mask_block_cnt[b, h, m].item()}, " + f"expected {mask_block_cnt_ref[b, h, m].item()}" + ) + if not full_cnt_match: + error_msg.append("Full counts mismatch") + diff = (full_block_cnt != full_block_cnt_ref).nonzero(as_tuple=False) + if len(diff) > 0: + b, h, m = diff[0].tolist() + error_msg.append( + f" First mismatch at [{b},{h},{m}]: " + f"got {full_block_cnt[b, h, m].item()}, " + f"expected {full_block_cnt_ref[b, h, m].item()}" + ) + return False, "\n".join(error_msg) + + # Compare indices + for b in range(batch_size): + for h in range(nheads): + for m in range(n_blocks_q): + num_mask = mask_block_cnt[b, h, m].item() + num_full = full_block_cnt[b, h, m].item() + + if num_mask > 0: + mask_indices = mask_block_idx[b, h, m, :num_mask].sort()[0] + mask_indices_ref = mask_block_idx_ref[b, h, m, :num_mask].sort()[0] + if not (mask_indices == mask_indices_ref).all(): + return False, f"Mask indices mismatch at [{b},{h},{m}]" + + if num_full > 0: + full_indices = full_block_idx[b, h, m, :num_full].sort()[0] + full_indices_ref = full_block_idx_ref[b, h, m, :num_full].sort()[0] + if not (full_indices == full_indices_ref).all(): + return False, f"Full indices mismatch at [{b},{h},{m}]" + + return True, "" + + +# Test configurations +SEQLEN_PAIRS = [ + # Small aligned + (64, 64), + (128, 128), + (256, 256), + (512, 512), + # Rectangular + (128, 256), + (256, 128), + (512, 256), + (256, 512), + # Large aligned + (1024, 1024), + (2048, 2048), + (4096, 4096), + # Large unaligned + (1000, 1000), + (2000, 2000), + (4000, 4000), + # Edge cases with unaligned seqlens + (113, 203), + (127, 127), + (129, 129), + (255, 255), + (257, 257), + (1023, 1023), + (1025, 1025), + (2047, 2047), + (2049, 2049), +] +TILE_SIZES = [ + # Standard powers of 2 + (32, 32), + (64, 64), + (128, 128), + (256, 256), + # Rectangular + (32, 64), + (64, 32), + (64, 128), + (128, 64), + (128, 256), + (256, 128), + # Unusual sizes + (40, 40), + (48, 48), + (96, 96), + (112, 112), + (32, 128), + (128, 32), + (40, 96), + (96, 40), +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize("tile_m,tile_n", TILE_SIZES) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal"]) +def test_fixed_length_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name +): + """Test fixed-length masks.""" + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize( + "mask_name,window_size", + [("causal", None), ("sliding_window", 64), ("sliding_window", 256)], +) +def test_parameterized_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name, window_size +): + """Test parameterized masks.""" + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("Sliding window not supported for seqlen_q > seqlen_k") + + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=window_size, + ) + ) + + _, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_k,tile_m,tile_n", + [ + (1, 1, 64, 64), + (63, 63, 64, 64), + (65, 65, 64, 64), + (129, 129, 128, 128), + (100, 200, 64, 128), + ], +) +def test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n): + """Test edge cases with unaligned dimensions.""" + batch_size, nheads = 1, 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + "causal", + ) + ) + + _, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["causal", "block_diagonal"]) +def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name): + """Test fast sampling mode (5-point sampling).""" + batch_size = 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + use_fast_sampling=True, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index bc41a56d813..fe1d18afb6d 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -2,33 +2,50 @@ import math import itertools +import os import pytest import torch from einops import rearrange, repeat + try: from flash_attn.layers.rotary import apply_rotary_emb except ImportError: apply_rotary_emb = None -# from padding import pad_input, unpad_input -from flash_attn.utils.testing import attention_ref, generate_qkv, generate_random_padding_mask -from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func +from flash_attn.cute.testing import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, + pad_input, + unpad_input, +) +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, +) + +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +TEST_BWD_ONLY = False +VERBOSE = True # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -38,12 +55,17 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 1), + (3, 3), + (64, 32), (64, 128), + (128, 128), (128, 192), (256, 256), (239, 1), @@ -60,47 +82,100 @@ (1024, 1024), (1023, 1024), (1024, 1023), + (2048, 2048), (4096, 4096), (4224, 4224), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, + seqlen_k, + d, + causal, + local_enum, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, ): - if causal and seqlen_k < seqlen_q: - pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + local = local_enum > 0 + if local and causal: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) + torch.cuda.empty_cache() + torch.cuda.synchronize() batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 1 nheads = 6 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: dv_vals = [d] - # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) + q_ref = q_ref * softcap / 4 q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] @@ -113,10 +188,13 @@ def test_flash_attn_output( None, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, - softcap=softcap + learnable_sink=learnable_sink, + softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, @@ -126,9 +204,12 @@ def test_flash_attn_output( None, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -137,13 +218,13 @@ def test_flash_attn_output( # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() - # if qv is not None: - # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = s_tmp.sum(-1) - # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) - # lse_ref = torch.logsumexp(qk, dim=-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() @@ -151,10 +232,11 @@ def test_flash_attn_output( print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + # num_splits_vals = [1, 3] + # pack_gqa_vals = [False, True, None] + # SplitKV is not supported for hdim >= 192 pack_gqa_vals = [False] - num_splits_vals = [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( q, @@ -163,11 +245,13 @@ def test_flash_attn_output( causal=causal, # qv=qv, # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - # window_size=window_size, + window_size=window_size, # attention_chunk=attention_chunk, softcap=softcap, - # pack_gqa=pack_gqa, - # num_splits=num_splits + learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -177,7 +261,9 @@ def test_flash_attn_output( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn @@ -185,6 +271,9 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 + and dv == d + and learnable_sink is None + # and False ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -199,9 +288,12 @@ def test_flash_attn_output( # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -215,29 +307,63 @@ def test_flash_attn_output( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + if VERBOSE: + diff_dq = (dq - dq_ref).abs() + max_idx = diff_dq.argmax() + coords = torch.unravel_index(max_idx, diff_dq.shape) + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}") + + diff_dk = (dk - dk_ref).abs() + max_idx = diff_dk.argmax() + coords = torch.unravel_index(max_idx, diff_dk.shape) + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}") + + diff_dv = (dv - dv_ref).abs() + max_idx = diff_dv.argmax() + coords = torch.unravel_index(max_idx, diff_dv.shape) + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}") + # breakpoint() - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mqa"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -246,6 +372,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", [128, 192]) @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", @@ -273,38 +400,86 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, ): + if ( + causal or local + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q + seqlen_k = seqlen_q device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) - batch_size = 9 if seqlen_q <= 2048 else 2 + batch_size = 49 if seqlen_q <= 1024 else 7 nheads = 6 # batch_size = 1 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] - # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. q_ref = (q_ref * softcap / 4).detach().requires_grad_() q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] @@ -315,7 +490,11 @@ def test_flash_attn_varlen_output( # TODO: test zero_lengths key_padding_mask = generate_random_padding_mask( # seqlen_k, batch_size, device, mode="random", zero_lengths=True - seqlen_k, batch_size, device, mode="random", zero_lengths=False + seqlen_k, + batch_size, + device, + mode="random", + zero_lengths=False, ) def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @@ -339,6 +518,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device ) + if causal or local: + key_padding_mask = query_padding_mask + ( q_unpad, k_unpad, @@ -357,9 +539,20 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] out_ref, attn_ref = attention_ref( q_ref, k_ref, @@ -368,10 +561,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, - softcap=softcap + learnable_sink=learnable_sink, + softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, @@ -381,9 +577,12 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -400,10 +599,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 - # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + # pack_gqa_vals = [False, True, None] pack_gqa_vals = [False] - num_splits_vals = [1] + # num_splits_vals = [1, 3] + # SplitKV is not supported for hdim >= 192 + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out_unpad, lse = flash_attn_varlen_func( q_unpad, @@ -412,16 +612,18 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, # max_seqlen_k, - seqused_q=seqused_q, - seqused_k=seqused_k, - max_seqlen_q=max_seqlen_q, + # seqused_q=seqused_q, + # seqused_k=seqused_k, causal=causal, # qv=qv_unpad, # q_descale=q_descale, # k_descale=k_descale, v_descale=v_descale, - # window_size=window_size, + window_size=window_size, # attention_chunk=attention_chunk, + learnable_sink=learnable_sink, softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -434,14 +636,17 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn and not has_qv and not dv > 256 and not attention_chunk != 0 + and dv == d + and not has_learnable_sink and False ): g_unpad = torch.randn_like(out_unpad) @@ -469,7 +674,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # deterministic, # 0, # sm_margin # ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) @@ -493,9 +700,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -510,9 +718,669 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("new_kv", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +@pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) +# @pytest.mark.parametrize("page_size", [None, 128]) +# @pytest.mark.parametrize("page_size", [128]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("varlen_q", [False, True]) +# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # # (1, 128 * 1024), + # # (16, 128 * 1024), + # (128, 128), + # (256, 512), # To test appending KV with more than 1 block + # (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + has_learnable_sink, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [d] + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + # has_qv = d == 64 and dv >= 256 + has_qv = False + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + if has_qv: + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + ) + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat( + [ + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + # num_splits_vals = [1, 0] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + # precompute_metadata_vals = [False, True] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): + # if precompute_metadata: + # scheduler_metadata = get_scheduler_metadata( + # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + # cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + # max_seqlen_k_new=seqlen_new, page_size=page_size, + # causal=causal, window_size=window_size, attention_chunk=attention_chunk, + # num_splits=num_splits + # ) + # else: + # scheduler_metadata = None + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + # out, lse, *rest = flash_attn_with_kvcache( + out, lse, *rest = flash_attn_varlen_func( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + # k if not new_kv or not varlen_q else k_unpad, + # v if not new_kv or not varlen_q else v_unpad, + # qv=qv if not varlen_q else qv_unpad, + # rotary_cos=cos, + # rotary_sin=sin, + seqused_k=cache_seqlens, + # cache_batch_idx=cache_batch_idx, + # cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, + # rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + # attention_chunk=attention_chunk, + # rotary_interleaved=rotary_interleaved, + # scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + # return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) + else: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() + + +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, seqlen, nheads) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where( + torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + ) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [11]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) + out_partial = torch.randn( + num_splits * 2, + batch_size, + nheads, + seqlen, + d, + device=device, + dtype=torch.float32, + ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn( + num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 + ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") + + # Test with LSE returned (default behavior) + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=True + ) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ( + (out - out_ref).abs().max().item() + <= multiple * (out_pt - out_ref).abs().max().item() + ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # Test with LSE not returned + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=False + ) + assert lse_no_lse is None, "LSE should be None when return_lse=False" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( + "Output should be the same regardless of return_lse" + ) diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py new file mode 100644 index 00000000000..520cf6466a7 --- /dev/null +++ b/tests/cute/test_flash_attn_race_condition.py @@ -0,0 +1,339 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import math +import itertools +import os + +import pytest +import torch + +from einops import rearrange, repeat + +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from flash_attn.cute.testing import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, + pad_input, + unpad_input, +) +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, + _flash_attn_bwd, +) + + +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["gqa"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [True]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (4224, 4224), + (2000, 4000), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, + seqlen_k, + d, + causal, + local_enum, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, +): + local = local_enum > 0 + if local and causal: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = q_ref * softcap / 4 + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) + # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # num_splits_vals = [1, 3] + # pack_gqa_vals = [False, True, None] + # SplitKV is not supported for hdim >= 192 + pack_gqa_vals = [False] + # num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + # qv=qv, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and softcap == 0.0 + and dv == d + and learnable_sink is None + # and False + ): + g = torch.randn_like(out) + # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + num_iters = 20_000 + for i in range(num_iters): + dq2, dk2, dv2, = _flash_attn_bwd( + q, k, v, out, g, lse, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + deterministic=True, + ) + + diff_dq = (dq - dq2).abs() + max_idx = diff_dq.argmax() + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at index {max_idx.item()}: dQ={dq.flatten()[max_idx].item()}, dQ2={dq2.flatten()[max_idx].item()}") + + diff_dk = (dk - dk2).abs() + max_idx = diff_dk.argmax() + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at index {max_idx.item()}: dK={dk.flatten()[max_idx].item()}, dK2={dk2.flatten()[max_idx].item()}") + + diff_dv = (dv - dv2).abs() + max_idx = diff_dv.argmax() + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at index {max_idx.item()}: dV={dv.flatten()[max_idx].item()}, dV2={dv2.flatten()[max_idx].item()}") + + # print(f"dQ max diff with myself: {(dq - dq2).abs().max().item()}") + # print(f"dK max diff with myself: {(dk - dk2).abs().max().item()}") + # print(f"dV max diff with myself: {(dv - dv2).abs().max().item()}") + # print(f"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}") + # print(f"dK mean diff with myself: {(dk - dk2).abs().mean().item()}") + # print(f"dV mean diff with myself: {(dv - dv2).abs().mean().item()}") + + assert torch.equal(dq, dq2) + assert torch.equal(dk, dk2) + assert torch.equal(dv, dv2) + + print(f"✅ Iteration {i} passed!") + diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py new file mode 100644 index 00000000000..53d907eed94 --- /dev/null +++ b/tests/cute/test_flash_attn_varlen.py @@ -0,0 +1,313 @@ +import itertools +from typing import Optional +from einops import rearrange +import pytest + +import torch +import torch.nn.functional as F +from flash_attn.cute import flash_attn_varlen_func + +@pytest.mark.parametrize("B", [1, 7, 20]) +@pytest.mark.parametrize("H", [1, 4, 6]) +@pytest.mark.parametrize("D", [64, 128]) +@pytest.mark.parametrize("min_seq_len", [1, 32, 128]) +@pytest.mark.parametrize("max_seq_len", [8, 64, 2048]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("softmax_scale", [None, 0.1]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +def test_varlen( + B, + H, + D, + min_seq_len, + max_seq_len, + causal, + softmax_scale, + dtype, + mha_type, +): + if min_seq_len > max_seq_len: + pytest.skip("Skipping min_seq_len > max_seq_len") + + q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k = generate_varlen_args( + batch_size=B, + n_heads=H, + d_head=D, + min_len=min_seq_len, + max_len=max_seq_len, + mha_type=mha_type, + dtype=dtype + ) + + # SM100 (Blackwell) backward pass doesn't support varlen yet + compute_capability = torch.cuda.get_device_capability()[0] + skip_backward = (compute_capability == 10) + + ok = check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + total_q=total_q, total_k=total_k, + softmax_scale=softmax_scale, + causal=causal, + mha_type=mha_type, + skip_backward=skip_backward, + ) + assert ok + +def check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + total_q=None, + total_k=None, + softmax_scale=None, + causal=True, + mha_type='mha', + softcap=0.0, + atol=3e-2, + rtol=3e-2, + skip_backward=False, +): + assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" + + def clone_like(t): + c = t.clone().detach().requires_grad_(True) + return c + + q_fa, k_fa, v_fa = map(clone_like, (q, k, v)) + q_t, k_t, v_t = map(clone_like, (q, k, v)) + + if cu_seqlens_q is not None: + cu_seqlens_q_fa = cu_seqlens_q.clone() + cu_seqlens_q_t = cu_seqlens_q.clone() + else: + cu_seqlens_q_fa = None + cu_seqlens_q_t = None + + if cu_seqlens_k is not None: + cu_seqlens_k_fa = cu_seqlens_k.clone() + cu_seqlens_k_t = cu_seqlens_k.clone() + else: + cu_seqlens_k_fa = None + cu_seqlens_k_t = None + + out_fa, lse_fa = flash_attn_varlen_func( + q_fa, k_fa, v_fa, + cu_seqlens_q=cu_seqlens_q_fa, + cu_seqlens_k=cu_seqlens_k_fa, + seqused_q=seqused_q, + seqused_k=seqused_k, + softmax_scale=(1.0 / q.shape[-1]**0.5) if softmax_scale is None else softmax_scale, + causal=causal, + window_size=(None, None), + learnable_sink=None, + softcap=softcap, + pack_gqa=None, + ) + + out_t = torch_flash_ref( + q_t, k_t, v_t, + cu_seqlens_q=cu_seqlens_q_t, + cu_seqlens_k=cu_seqlens_k_t, + seqused_q=seqused_q, + seqused_k=seqused_k, + total_q=total_q, + total_k=total_k, + softmax_scale=softmax_scale, + causal=causal, + mha_type=mha_type, + ) + + + ok_fwd = torch.allclose(out_fa.float(), out_t.float(), atol=atol, rtol=rtol) + if not ok_fwd: + return False + + # Skip backward if not supported (e.g., SM100 varlen) + if skip_backward: + return True + + # Use the same upstream gradient to compare backward paths + grad_out = torch.randn_like(out_fa) + + grad_fa = clone_like(grad_out) + grad_t = clone_like(grad_out) + + # Cute bwd + out_fa.backward(grad_fa, retain_graph=False) + dq_fa, dk_fa, dv_fa = q_fa.grad, k_fa.grad, v_fa.grad + + # Ref bwd + out_t.backward(grad_t, retain_graph=False) + dq_t, dk_t, dv_t = q_t.grad, k_t.grad, v_t.grad + + # mean_ok_q = _stats("dQ", dq_fa, dq_t, atol=atol, rtol=rtol) + # mean_ok_k = _stats("dK", dk_fa, dk_t, atol=atol, rtol=rtol) + # mean_ok_v = _stats("dV", dv_fa, dv_t, atol=atol, rtol=rtol) + + # return mean_ok_q and mean_ok_k and mean_ok_v + + ok_q = torch.allclose(dq_fa.float(), dq_t.float(), atol=atol, rtol=rtol) + ok_k = torch.allclose(dk_fa.float(), dk_t.float(), atol=atol, rtol=rtol) + ok_v = torch.allclose(dv_fa.float(), dv_t.float(), atol=atol, rtol=rtol) + # print(f"Close? dQ={ok_q}, dK={ok_k}, dV={ok_v}") + return ok_q and ok_k and ok_v + +def generate_varlen_args( + batch_size=8, + n_heads=16, + d_head=128, + min_len=32, + max_len=64, + mha_type="mha", + dtype = torch.bfloat16, +): + + torch.manual_seed(0) + device = "cuda" + + assert mha_type in ["mha", "mqa", "gqa"] + + lens_q = torch.randint(low=min_len, high=max_len + 1, size=(batch_size,)) + lens_k = lens_q.clone() + + cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32), lens_q.cumsum(0)]) + cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32), lens_k.cumsum(0)]) + + total_q = cu_seqlens_q[-1] + total_k = cu_seqlens_k[-1] + + cu_seqlens_q = cu_seqlens_q.contiguous().to(dtype=torch.int32, device=device) + cu_seqlens_k = cu_seqlens_k.contiguous().to(dtype=torch.int32, device=device) + + if mha_type == "gqa": + H = 3 * n_heads + H_kv = n_heads + elif mha_type == "mha": + H = H_kv = n_heads + else: # MQA + H = n_heads + H_kv = 1 + + d_head_v = d_head + + q = torch.randn(total_q, H, d_head, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(total_k, H_kv, d_head, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(total_k, H_kv, d_head_v, device=device, dtype=dtype, requires_grad=True) + + return q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k + +# Simple for loop over batch dim implementation +def torch_flash_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_k: torch.Tensor = None, + total_q: int = 0, + total_k: int = 0, + softmax_scale: Optional[float] = None, + causal: bool = False, + **kwargs + ): + + """ + q: (total_q, H, d) if cu_seqlens_q is not None, otherwise (B, L, H, d) + k: (total_k, H_kv, d) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d) + v: (total_k, H_kv, d_v) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d_v) + cu_seqlens_q: (B+1,) int32, cumulative + cu_seqlens_k: (B+1,) int32, cumulative + + seqused_q: (B+1,) int32 + seqused_k: (B+1,) int32 + Returns: + out packed like q: (total_q, H, d_v) + """ + + if cu_seqlens_q is not None: + assert cu_seqlens_q.dim() == 1 + assert total_q == q.shape[0] + assert q.dim() == 3 + H = q.shape[1] + B = cu_seqlens_q.shape[0] - 1 + else: + assert q.dim() == 4 + H = q.shape[2] + B = q.shape[0] + + if cu_seqlens_k is not None: + assert cu_seqlens_k.dim() == 1 + assert total_k == k.shape[0] == v.shape[0] + assert k.dim() == v.dim() == 3 + H_kv = k.shape[1] + B_kv = cu_seqlens_k.shape[0] - 1 + else: + assert k.dim() == v.dim() == 4 + assert k.shape[0] == v.shape[0] + H_kv = k.shape[2] + B_kv = k.shape[0] + + d = q.shape[-1] + d_v = v.shape[-1] + + assert H_kv == v.shape[-2] + assert d == k.shape[-1] + assert B == B_kv + + assert q.device == k.device == v.device + assert q.is_floating_point() and k.is_floating_point() and v.is_floating_point() + + device = q.device + dtype = q.dtype + + hcseq_q = cu_seqlens_q.to(device='cpu') + hcseq_k = cu_seqlens_k.to(device='cpu') + + outs = [] + for b in range(B): + if hcseq_q is not None: + q_start, q_end = int(hcseq_q[b]), int(hcseq_q[b+1]) + qb = q[q_start:q_end] + else: + qb = q[b] + + if hcseq_k is not None: + k_start, k_end = int(hcseq_k[b]), int(hcseq_k[b+1]) + kb = k[k_start:k_end] + vb = v[k_start:k_end] + else: + kb = k[b] + vb = v[b] + + qb = qb.permute(1, 0, 2).unsqueeze(0) + kb = kb.permute(1, 0, 2).unsqueeze(0) + vb = vb.permute(1, 0, 2).unsqueeze(0) + + ob = F.scaled_dot_product_attention( + qb, kb, vb, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + enable_gqa=H_kv!=H + ) + + ob = ob.squeeze(0).permute(1, 0, 2).contiguous() + outs.append(ob) + + if cu_seqlens_q is not None: + out = torch.cat(outs, dim=0).to(device=device, dtype=dtype) + else: + out = torch.stack(outs, dim=0).to(device=device, dtype=dtype) + return out + +@torch.no_grad() +def _stats(name, a, b, atol, rtol): + diff = (a - b).float() + mean_abs = diff.abs().mean().item() + mean_rel = (diff.abs().mean() / b.abs().clamp_min(1e-6).mean().item()) + print(f"{name}: mean_abs={mean_abs:.4e}, mean_rel={mean_rel:.4e}, sum_fa={a.sum()}, sum_ref={b.sum()}") + return mean_abs < atol and mean_rel < rtol \ No newline at end of file diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py new file mode 100644 index 00000000000..9c2db48f22b --- /dev/null +++ b/tests/cute/test_mask_mod.py @@ -0,0 +1,515 @@ +# mask mod test script +# REFACTORED to use _flash_attn_fwd as the kernel entrypoint +# +# Test Organization: +# - test_static_masks: Fast tests for masks that don't need per-seqlen compilation +# (identity, document, block_diagonal, etc.) with comprehensive seqlen coverage +# - test_parameterized_masks: Slower tests for masks that require recompilation per +# seqlen pair (causal, block_causal, sliding_window) with reduced seqlen coverage +# +# Usage: +# pytest test_mask_mod.py::test_static_masks # Run only fast tests +# pytest test_mask_mod.py::test_parameterized_masks # Run only slow tests +# pytest test_mask_mod.py # Run all tests + +import math +from typing import Optional + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +import torch.nn.functional as F + +from flash_attn.cute.interface import _flash_attn_fwd +from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch +from flash_attn.cute.mask_definitions import ( + get_mask_pair, + STATIC_MASKS, + random_doc_id_tensor, +) +from flash_attn.cute.testing import attention_ref +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + + +@pytest.fixture(autouse=True) +def reset_torch_state(): + """Reset torch dynamo/compile state between tests to avoid state pollution.""" + torch._dynamo.reset() + torch.cuda.empty_cache() + + yield + + torch._dynamo.reset() + torch.cuda.empty_cache() + +def create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype +): + device = "cuda" + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype) + k = torch.randn( + batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype + ) + out = torch.empty( + batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype + ) + lse = torch.empty(batch_size, nheads, seqlen_q, device=device, dtype=torch.float32) + + return { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + } + + +def compute_reference_flash_attn(tensors, causal, window_size, dtype_ref, upcast=True): + """Compute reference using FlashAttention's attention_ref function""" + q = tensors["q"].to(dtype_ref) + k = tensors["k"].to(dtype_ref) + v = tensors["v"].to(dtype_ref) + + out_ref, attn_ref = attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=causal, + window_size=window_size, + upcast=upcast, + reorder_ops=False, + ) + + return out_ref + + +def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: Optional[tuple[int, int]] = None): + """Compute reference using flex_attention for custom mask_mods""" + batch_size, seqlen_q, nheads, headdim = tensors["q"].shape + _, seqlen_k, nheads_kv, _ = tensors["k"].shape + + q = tensors["q"].transpose(1, 2) + k = tensors["k"].transpose(1, 2) + v = tensors["v"].transpose(1, 2) + + if nheads != nheads_kv: + repeat_factor = nheads // nheads_kv + k = k.repeat_interleave(repeat_factor, dim=1) + v = v.repeat_interleave(repeat_factor, dim=1) + + scale = 1.0 / math.sqrt(headdim) + + # Handle identity (no masking) case + if mask_mod_flex is None: + out_ref = F.scaled_dot_product_attention(q, k, v, scale=scale) + return out_ref.transpose(1, 2).contiguous() + + block_mask_kwargs = {} + if block_size is not None: + block_mask_kwargs["BLOCK_SIZE"] = block_size + + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device=q.device, + **block_mask_kwargs, + ) + out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale) + return out_ref.transpose(1, 2).contiguous() + + +SEQLEN_PAIRS_COMPREHENSIVE = [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), +] + +SEQLEN_PAIRS_SMOKE = [ + (128, 128), + (256, 256), + (113, 203), + (1024, 1024), + (128, 8192) +] + + +def _run_mask_test( + seqlen_q, + seqlen_k, + nheads, + kv_mode, + headdim, + dtype, + mask_name, + window_size, + window_left, + window_right, + tile_m, + tile_n, + use_block_sparsity, +): + torch.manual_seed(42) + + if mask_name == "sliding_window": + assert window_size is not None, ( + "window_size must be specified for sliding_window" + ) + if seqlen_q > seqlen_k: + pytest.skip( + f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window" + ) + + # Determine nheads_kv based on mode + if kv_mode == "mha": + nheads_kv = nheads + elif kv_mode == "gqa": + nheads_kv = nheads // 2 + elif kv_mode == "mqa": + nheads_kv = 1 + else: + raise ValueError(f"Unknown kv_mode: {kv_mode}") + + batch_size = 1 + headdim_v = headdim + + aux_tensors_arg = None + mask_mod_cute, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + if mask_name == "document": + doc_len = max(seqlen_q, seqlen_k) + doc_ids = random_doc_id_tensor(nheads, batch_size, doc_len, device="cuda").to( + dtype=torch.int32, device="cuda" + ) + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): + return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + + aux_tensors_arg = [doc_ids] + elif mask_name == "ima": + bias_threshold = (seqlen_k // 4) * 3 + bias = torch.full((seqlen_k,), bias_threshold, dtype=torch.int32, device="cuda") + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): + return original_flex_mask(b, h, q_idx, kv_idx, bias) + + aux_tensors_arg = [bias] + causal = False + + if causal and seqlen_k < seqlen_q: + pytest.skip("causal masking requires seqlen_k >= seqlen_q") + + tensors = create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype + ) + + # Compute block sparsity for mask_mod + if COMPUTE_CAPABILITY == 10: + sparse_tile_m = 2 * tile_m + else: + sparse_tile_m = tile_m + + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + _, _, mask_cnt, mask_idx, full_cnt, full_idx, *_ = bm.as_tuple() + + softmax_scale = 1.0 / math.sqrt(headdim) + + # if full_cnt is not None: + # print(f"Block sparsity info for {mask_name}:") + # print(f" full_cnt shape: {full_cnt.shape}") + # print(f" full_idx shape: {full_idx.shape}") + # print(f" mask_cnt shape: {mask_cnt.shape}") + # print(f" mask_idx shape: {mask_idx.shape}") + # print(f" full_cnt: {full_cnt}") + # print(f" full_idx: {full_idx}") + # print(f" mask_cnt: {mask_cnt}") + # print(f" mask_idx: {mask_idx}") + # if full_cnt[0,0,0] > 0: + # print(f" First Q block - full indices: {full_idx[0,0,0,:full_cnt[0,0,0].item()]}") + # if mask_cnt[0,0,0] > 0: + # print(f" First Q block - mask indices: {mask_idx[0,0,0,:mask_cnt[0,0,0].item()]}") + block_sparse_mask = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + ) if use_block_sparsity else None + + out_tuple = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"], + lse=tensors["lse"], + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=causal, + softcap=None, + window_size_left=window_left, + window_size_right=window_right, + learnable_sink=None, + m_block_size=tile_m, + n_block_size=tile_n, + num_threads=384, + pack_gqa=False, + _compute_capability=None, + score_mod=None, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + + out_cute = out_tuple[0] + tensors_fp32 = { + k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v + for k, v in tensors.items() + } + + block_size = (tile_m, tile_n) + out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, mask_mod_flex, block_size) + out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size) + out_pt = out_ref.clone() + + # Check for invalid values + assert out_cute.shape == out_ref_fp32.shape == out_ref.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + + # Compute numerical tolerance (matching flash attention tests) + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + ref_error = (out_ref - out_ref_fp32).abs().max().item() + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + mask_desc = f"mask_mod={mask_name}" + if mask_name == "sliding_window" and window_size is not None: + mask_desc += f"(w={window_size})" + + print( + f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), " + f"D={headdim}, M={tile_m}, N={tile_n}" + ) + print(" Reference implementation: FlexAttention") + print(f" Reference vs FP32: {ref_error:.2e}") + print(f" PyTorch vs FP32: {pt_error:.2e}") + print(f" Kernel vs FP32: {cute_error:.2e}") + print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}") + print(f" Error ratio: {cute_error / max(pt_error, 1e-10):.2f}") + + # Debug: show some sample values if error is large + if cute_error > 1e-2: + print(f" DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}") + print(f" DEBUG: Sample reference output: {out_ref_fp32[0, 0, 0, :5]}") + print(f" DEBUG: Max diff location: {(out_cute - out_ref_fp32).abs().argmax()}") + max_diff_idx = (out_cute - out_ref_fp32).abs().argmax() + max_diff_coords = torch.unravel_index(max_diff_idx, out_cute.shape) + print(f" DEBUG: Max diff at coords: {max_diff_coords}") + print(f" DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}") + print(f" DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}") + + # Use the same assertion logic as FlashAttention tests + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +def test_mask_mod_ima_partial_block(): + _run_mask_test( + seqlen_q=257, + seqlen_k=257, + nheads=1, + kv_mode="mha", + headdim=128, + dtype=torch.bfloat16, + mask_name="ima", + window_size=None, + window_left=None, + window_right=None, + tile_m=128, + tile_n=128, + use_block_sparsity=True, + ) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_COMPREHENSIVE) +@pytest.mark.parametrize("nheads", [16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("headdim", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("use_block_sparsity", [True, False]) +@pytest.mark.parametrize( + "mask_name", + ["block_diagonal", "mini_causal"], +) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)]) +def test_static_masks( + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, tile_m, tile_n +): + """Test static masks that don't require recompilation per seqlen pair. + + Known good masks: + - block_diagonal: Masks by 64-element diagonal blocks + - mini_causal: Local causal within 128-element tiles + """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. due to TMEM") + + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=nheads, + kv_mode=kv_mode, + headdim=headdim, + dtype=dtype, + mask_name=mask_name, + window_size=None, + window_left=None, + window_right=None, + tile_m=tile_m, + tile_n=tile_n, + use_block_sparsity=use_block_sparsity, + ) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_SMOKE) +@pytest.mark.parametrize("nheads", [16]) +@pytest.mark.parametrize("kv_mode", ["mha"]) +@pytest.mark.parametrize("headdim", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("use_block_sparsity", [True, False]) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("sliding_window", 512), + ("document", None), + ], +) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112), (64, 128)]) +def test_parameterized_masks( + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n +): + """Test parameterized masks that require recompilation per seqlen pair. + + Uses fewer seqlen combinations to reduce test time. + + Masks tested: + - causal, block_causal: Require offset = seqlen_k - seqlen_q + - sliding_window: Requires window size and offset parameters + - document: Slower to check + """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. due to TMEM") + + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=nheads, + kv_mode=kv_mode, + headdim=headdim, + dtype=dtype, + mask_name=mask_name, + window_size=window_size, + window_left=None, + window_right=None, + tile_m=tile_m, + tile_n=tile_n, + use_block_sparsity=use_block_sparsity, + ) + + +def test_sm100_block_sparse_sink_all_masked(): + """Block-sparse regression for the sink path""" + if torch.cuda.get_device_capability()[0] != 10: + pytest.skip("SM100-only test") + device = "cuda" + dtype = torch.bfloat16 + batch_size = 1 + seqlen_q = 256 + seqlen_k = 128 + nheads = 8 + headdim = 128 + q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + learnable_sink = torch.full((nheads,), 0.5, dtype=torch.bfloat16, device=device) + zero_cnt = torch.zeros((batch_size, nheads, 1), dtype=torch.int32, device=device) + zero_idx = torch.zeros((batch_size, nheads, 1, 1), dtype=torch.int32, device=device) + sparse = BlockSparseTensorsTorch( + mask_block_cnt=zero_cnt, + mask_block_idx=zero_idx, + full_block_cnt=zero_cnt, + full_block_idx=zero_idx, + ) + softmax_scale = 1.0 / math.sqrt(headdim) + _, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=False, + window_size_left=None, + window_size_right=None, + learnable_sink=learnable_sink, + m_block_size=128, + n_block_size=128, + num_threads=384, + pack_gqa=False, + block_sparse_tensors=sparse, + return_lse=True, + ) + # Fully masked tile ⇒ probability mass sits entirely on the sink, so LSE equals sink logit. + expected = learnable_sink.float()[None, :, None].expand_as(lse) + assert torch.allclose(lse, expected, atol=0.0, rtol=0.0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) + \ No newline at end of file diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py new file mode 100644 index 00000000000..147e5519394 --- /dev/null +++ b/tests/cute/test_score_mod.py @@ -0,0 +1,499 @@ +import pytest +import torch +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import math as mlir_math +import operator +from torch.nn.attention.flex_attention import flex_attention +from flash_attn.cute.interface import _flash_attn_fwd + + +@cute.jit +def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = tSrS_ssa + tSrS_ssa = tmp0 + return tSrS_ssa + + +@cute.jit +def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = operator.ge(tmp0, tmp1) + tmp3 = tSrS_ssa + tmp4 = cute.where(tmp2, tmp3, cute.full_like(tmp3, float("-inf"))) + tSrS_ssa = tmp4 + return tSrS_ssa + + +@cute.jit +def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = tSrS_ssa + tmp1 = q_idx + tmp2 = kv_idx + tmp3 = tmp1 - tmp2 + tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype) + tmp5 = tmp4.to(cutlass.Float32) + tmp6 = tmp0 + tmp5 + tSrS_ssa = tmp6 + return tSrS_ssa + + +@cute.jit +def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = tSrS_ssa + tmp1 = q_idx + tmp2 = kv_idx + tmp3 = tmp1 - tmp2 + tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype) + tmp5 = tmp4 * cute.full_like(tmp4, 2) + tmp6 = tmp5.to(cutlass.Float32) + tmp7 = tmp0 + tmp6 + tSrS_ssa = tmp7 + return tSrS_ssa + + +@cute.jit +def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = tSrS_ssa + tmp1 = tmp0 * cute.full_like(tmp0, 2) + tSrS_ssa = tmp1 + return tSrS_ssa + + +@cute.jit +def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = tSrS_ssa + tmp1 = tmp0.to(cutlass.Float32) + tmp2 = h_idx + tmp3 = tmp2 + cute.full_like(tmp2, 1) + tmp4 = tmp3 * cute.full_like(tmp3, -8) + tmp5 = tmp4.to(cutlass.Float32) + tmp6 = tmp5 * cute.full_like(tmp5, 0.125) + tmp7 = tmp6 * cute.full_like(tmp6, 0.6931471805599453) + tmp8 = cute.math.exp2(tmp7 * 1.4426950408889634) + tmp9 = q_idx + tmp10 = kv_idx + tmp11 = tmp9 - tmp10 + tmp12 = cute.TensorSSA(mlir_math.absi(tmp11), tmp11.shape, tmp11.dtype) + tmp13 = tmp12.to(cutlass.Float32) + tmp14 = tmp8 * tmp13 + tmp15 = tmp1 - tmp14 + tSrS_ssa = tmp15 + return tSrS_ssa + + +@cute.jit +def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = tmp0 - tmp1 + tmp3 = cute.TensorSSA(mlir_math.absi(tmp2), tmp2.shape, tmp2.dtype) + tmp4 = operator.le(tmp3, cute.full_like(tmp3, 256)) + tmp5 = tSrS_ssa + tmp6 = cute.where(tmp4, tmp5, cute.full_like(tmp5, float("-inf"))) + tSrS_ssa = tmp6 + return tSrS_ssa + + +@cute.jit +def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = tSrS_ssa + tmp3 = cute.where( + operator.eq(tmp0 // 64, tmp1 // 64), tmp2, cute.full_like(tmp2, float("-inf")) + ) + tSrS_ssa = tmp3 + return tSrS_ssa + + +@cute.jit +def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = tmp0 - tmp1 + tmp3 = operator.ge(tmp2, cute.full_like(tmp2, 0)) + tmp4 = tSrS_ssa + tmp5 = cute.where(tmp3, tmp4, cute.full_like(tmp4, float("-inf"))) + tSrS_ssa = tmp5 + return tSrS_ssa + + +@cute.jit +def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + batch_bias = aux_tensors[0] + + # Detect dtype from buffer element type + dtype = batch_bias.element_type + + b_frag = cute.make_fragment(1, cutlass.Int32) + b_frag.store(b_idx) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = batch_bias[b_frag[0]] + bias_val = (bias_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + bias_val + + +@cute.jit +def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] + + # Detect dtype from buffer element type + dtype = head_bias.element_type + + h_frag = cute.make_fragment(1, cutlass.Int32) + h_frag.store(h_idx) + head_val_frag = cute.make_fragment(1, dtype) + head_val_frag[0] = head_bias[h_frag[0]] + head_val = (head_val_frag.load()).to(cutlass.Float32) + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx) + pos_val_frag = cute.make_fragment(1, dtype) + pos_val_frag[0] = pos_bias[q_frag[0]] + pos_val = (pos_val_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + head_val + pos_val + + +# Eager reference functions for comparison +def identity_eager(score, b, h, q_idx, kv_idx): + return score + + +def causal_mask_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx >= kv_idx, score, float("-inf")) + + +def relative_bias_eager(score, b, h, q_idx, kv_idx): + return score + torch.abs(q_idx - kv_idx) + + +def relative_bias_v2_eager(score, b, h, q_idx, kv_idx): + return score + 2 * torch.abs(q_idx - kv_idx) + + +def times_two_eager(score, b, h, q_idx, kv_idx): + return score * 2 + + +def alibi_bias_eager(score, b, h, q_idx, kv_idx): + slope = 2 ** (-8 * (h + 1) / 8) + return score - slope * torch.abs(q_idx - kv_idx) + + +def sliding_window_eager(score, b, h, q_idx, kv_idx): + return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float("-inf")) + + +def block_diagonal_eager(score, b, h, q_idx, kv_idx): + q_block = q_idx // 64 + kv_block = kv_idx // 64 + return torch.where(q_block == kv_block, score, float("-inf")) + + +def causal_mask_v2_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx - kv_idx >= 0, score, float("-inf")) + + +def batch_bias(bias_tensor): + """Per-batch bias (tests batch indexing).""" + + def batch_bias_mod(score, b, h, q_idx, kv_idx): + return score + bias_tensor[b] + + return batch_bias_mod + + +def dual_buffer_bias(head_bias, pos_scale): + """Dual buffer loading (tests loading from 2 separate tensors).""" + + def dual_buffer_mod(score, b, h, q_idx, kv_idx): + head_component = head_bias[h] + pos_component = pos_scale[q_idx] + return score + pos_component + head_component + + return dual_buffer_mod + + +# Test pairs: (cute_jit_function, eager_reference_function) +TEST_PAIRS = [ + (score_mod_1, None), + (score_mod_2, causal_mask_eager), + (score_mod_3, relative_bias_eager), + (score_mod_4, relative_bias_v2_eager), + (score_mod_5, times_two_eager), + (score_mod_6, alibi_bias_eager), + (score_mod_7, sliding_window_eager), + (score_mod_8, block_diagonal_eager), + (score_mod_9, causal_mask_v2_eager), +] + +# Test pairs with aux_tensors: (cute_jit_function, eager_reference_function_factory) +TEST_PAIRS_WITH_AUX_TENSORS = [ + (score_mod_10, batch_bias), + (score_mod_11, dual_buffer_bias), +] + + +def create_tensors( + batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 +): + q = torch.randn(batch_size, num_heads, seqlen_q, dim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) + return q, k, v + + +def run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False +) -> torch.Tensor: + q_transposed, k_transposed, v_transposed = map( + lambda x: x.transpose(1, 2), (q, k, v) + ) + out = torch.empty_like(q_transposed) + _flash_attn_fwd( + q_transposed, + k_transposed, + v_transposed, + return_lse=True, + score_mod=cute_score_mod, + out=out, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + return out.transpose(1, 2) + + +def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: + if dtype is not None: + q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) + return flex_attention( + q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] + ) + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) +def test_cute_vs_flex_attention( + seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair +): + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod = score_mod_pair + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) + + out_pt = run_flex_reference(q, k, v, eager_score_mod) + out_cute = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) + + # Basic shape and NaN checks + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + # Calculate actual errors + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print(f"\nNumerical comparison for {cute_score_mod.__name__}:") + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +def test_cute_vs_flex_attention_with_aux_tensors( + seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair +): + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod_factory = score_mod_pair + + batch_size = 2 + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + batch_size=batch_size, + seqlen_q=seqlen_q, + seqlen_kv=seqlen_kv, + num_heads=num_q_heads, + dtype=dtype, + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [buffer] + eager_score_mod = eager_score_mod_factory(buffer) + assert buffer.shape == (batch_size,) + elif cute_score_mod == score_mod_11: + head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_scale] + eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) + assert head_bias.shape == (num_q_heads,) + assert pos_scale.shape == (seqlen_q,) + + out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) + + out_pt = run_flex_reference(q, k, v, eager_score_mod) + out_cute = run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa + ) + + # Basic shape and NaN checks + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + # Calculate actual errors + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print(f"\nNumerical comparison for {cute_score_mod.__name__}:") + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@pytest.mark.xfail( + raises=NotImplementedError, reason="Varlen with score_mod not yet supported" +) +def test_varlen_with_score_mod(): + """Test that varlen (variable length sequences) works with score_mod. + + For varlen, tokens from different sequences should not attend to each other. + Without proper index mapping, the causal mask will be applied to the global + indices instead of per-sequence logical indices. + """ + torch.random.manual_seed(42) + + seqlens = [64, 56, 128] + total_seq = sum(seqlens) + num_heads = 4 + dtype = torch.bfloat16 + + cu_seqlens = torch.tensor( + [0] + list(torch.tensor(seqlens).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + q = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) + k = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) + v = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) + + out_cute = torch.empty_like(q) + + _flash_attn_fwd( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + return_lse=True, + score_mod=score_mod_2, + out=out_cute, + lse=None, + ) + + assert not torch.isnan(out_cute).any(), "Output contains NaN values" + assert torch.isfinite(out_cute).all(), "Output contains infinite values" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 503b7bf01c3..d5590fcfc82 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -1399,7 +1399,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( + assert (q.grad - q_ref.grad).abs().max().item() <= 7 * ( q_pt.grad - q_ref.grad ).abs().max().item() + 1e-3 assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (